1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/ipc/client_impl.h"
18 
19 #include <fcntl.h>
20 
21 #include <cinttypes>
22 #include <utility>
23 
24 #include "perfetto/base/task_runner.h"
25 #include "perfetto/ext/base/unix_socket.h"
26 #include "perfetto/ext/base/utils.h"
27 #include "perfetto/ext/ipc/service_descriptor.h"
28 #include "perfetto/ext/ipc/service_proxy.h"
29 
30 #include "protos/perfetto/ipc/wire_protocol.gen.h"
31 
32 // TODO(primiano): Add ThreadChecker everywhere.
33 
34 // TODO(primiano): Add timeouts.
35 
36 namespace perfetto {
37 namespace ipc {
38 
39 namespace {
40 constexpr base::SockFamily kClientSockFamily =
41     kUseTCPSocket ? base::SockFamily::kInet : base::SockFamily::kUnix;
42 }  // namespace
43 
44 // static
CreateInstance(ConnArgs conn_args,base::TaskRunner * task_runner)45 std::unique_ptr<Client> Client::CreateInstance(ConnArgs conn_args,
46                                                base::TaskRunner* task_runner) {
47   std::unique_ptr<Client> client(
48       new ClientImpl(std::move(conn_args), task_runner));
49   return client;
50 }
51 
ClientImpl(ConnArgs conn_args,base::TaskRunner * task_runner)52 ClientImpl::ClientImpl(ConnArgs conn_args, base::TaskRunner* task_runner)
53     : socket_name_(conn_args.socket_name),
54       socket_retry_(conn_args.retry),
55       task_runner_(task_runner),
56       weak_ptr_factory_(this) {
57   if (conn_args.socket_fd) {
58     // Create the client using a connected socket. This code path will never hit
59     // OnConnect().
60     sock_ = base::UnixSocket::AdoptConnected(
61         std::move(conn_args.socket_fd), this, task_runner_, kClientSockFamily,
62         base::SockType::kStream, base::SockPeerCredMode::kIgnore);
63   } else {
64     // Connect using the socket name.
65     TryConnect();
66   }
67 }
68 
~ClientImpl()69 ClientImpl::~ClientImpl() {
70   // Ensure we are not destroyed in the middle of invoking a reply.
71   PERFETTO_DCHECK(!invoking_method_reply_);
72   OnDisconnect(
73       nullptr);  // The base::UnixSocket* ptr is not used in OnDisconnect().
74 }
75 
TryConnect()76 void ClientImpl::TryConnect() {
77   PERFETTO_DCHECK(socket_name_);
78   sock_ = base::UnixSocket::Connect(socket_name_, this, task_runner_,
79                                     kClientSockFamily, base::SockType::kStream,
80                                     base::SockPeerCredMode::kIgnore);
81 }
82 
BindService(base::WeakPtr<ServiceProxy> service_proxy)83 void ClientImpl::BindService(base::WeakPtr<ServiceProxy> service_proxy) {
84   if (!service_proxy)
85     return;
86   if (!sock_->is_connected()) {
87     queued_bindings_.emplace_back(service_proxy);
88     return;
89   }
90   RequestID request_id = ++last_request_id_;
91   Frame frame;
92   frame.set_request_id(request_id);
93   Frame::BindService* req = frame.mutable_msg_bind_service();
94   const char* const service_name = service_proxy->GetDescriptor().service_name;
95   req->set_service_name(service_name);
96   if (!SendFrame(frame)) {
97     PERFETTO_DLOG("BindService(%s) failed", service_name);
98     return service_proxy->OnConnect(false /* success */);
99   }
100   QueuedRequest qr;
101   qr.type = Frame::kMsgBindServiceFieldNumber;
102   qr.request_id = request_id;
103   qr.service_proxy = service_proxy;
104   queued_requests_.emplace(request_id, std::move(qr));
105 }
106 
UnbindService(ServiceID service_id)107 void ClientImpl::UnbindService(ServiceID service_id) {
108   service_bindings_.erase(service_id);
109 }
110 
BeginInvoke(ServiceID service_id,const std::string & method_name,MethodID remote_method_id,const ProtoMessage & method_args,bool drop_reply,base::WeakPtr<ServiceProxy> service_proxy,int fd)111 RequestID ClientImpl::BeginInvoke(ServiceID service_id,
112                                   const std::string& method_name,
113                                   MethodID remote_method_id,
114                                   const ProtoMessage& method_args,
115                                   bool drop_reply,
116                                   base::WeakPtr<ServiceProxy> service_proxy,
117                                   int fd) {
118   RequestID request_id = ++last_request_id_;
119   Frame frame;
120   frame.set_request_id(request_id);
121   Frame::InvokeMethod* req = frame.mutable_msg_invoke_method();
122   req->set_service_id(service_id);
123   req->set_method_id(remote_method_id);
124   req->set_drop_reply(drop_reply);
125   req->set_args_proto(method_args.SerializeAsString());
126   if (!SendFrame(frame, fd)) {
127     PERFETTO_DLOG("BeginInvoke() failed while sending the frame");
128     return 0;
129   }
130   if (drop_reply)
131     return 0;
132   QueuedRequest qr;
133   qr.type = Frame::kMsgInvokeMethodFieldNumber;
134   qr.request_id = request_id;
135   qr.method_name = method_name;
136   qr.service_proxy = std::move(service_proxy);
137   queued_requests_.emplace(request_id, std::move(qr));
138   return request_id;
139 }
140 
SendFrame(const Frame & frame,int fd)141 bool ClientImpl::SendFrame(const Frame& frame, int fd) {
142   // Serialize the frame into protobuf, add the size header, and send it.
143   std::string buf = BufferedFrameDeserializer::Serialize(frame);
144 
145   // TODO(primiano): this should do non-blocking I/O. But then what if the
146   // socket buffer is full? We might want to either drop the request or throttle
147   // the send and PostTask the reply later? Right now we are making Send()
148   // blocking as a workaround. Propagate bakpressure to the caller instead.
149   bool res = sock_->Send(buf.data(), buf.size(), fd);
150   PERFETTO_CHECK(res || !sock_->is_connected());
151   return res;
152 }
153 
OnConnect(base::UnixSocket *,bool connected)154 void ClientImpl::OnConnect(base::UnixSocket*, bool connected) {
155   if (!connected && socket_retry_) {
156     socket_backoff_ms_ =
157         (socket_backoff_ms_ < 10000) ? socket_backoff_ms_ + 1000 : 30000;
158     PERFETTO_DLOG(
159         "Connection to traced's UNIX socket failed, retrying in %u seconds",
160         socket_backoff_ms_ / 1000);
161     auto weak_this = weak_ptr_factory_.GetWeakPtr();
162     task_runner_->PostDelayedTask(
163         [weak_this] {
164           if (weak_this)
165             static_cast<ClientImpl&>(*weak_this).TryConnect();
166         },
167         socket_backoff_ms_);
168     return;
169   }
170 
171   // Drain the BindService() calls that were queued before establishing the
172   // connection with the host. Note that if we got disconnected, the call to
173   // OnConnect below might delete |this|, so move everything on the stack first.
174   auto queued_bindings = std::move(queued_bindings_);
175   queued_bindings_.clear();
176   for (base::WeakPtr<ServiceProxy>& service_proxy : queued_bindings) {
177     if (connected) {
178       BindService(service_proxy);
179     } else if (service_proxy) {
180       service_proxy->OnConnect(false /* success */);
181     }
182   }
183   // Don't access |this| below here.
184 }
185 
OnDisconnect(base::UnixSocket *)186 void ClientImpl::OnDisconnect(base::UnixSocket*) {
187   for (const auto& it : service_bindings_) {
188     base::WeakPtr<ServiceProxy> service_proxy = it.second;
189     task_runner_->PostTask([service_proxy] {
190       if (service_proxy)
191         service_proxy->OnDisconnect();
192     });
193   }
194   for (const auto& it : queued_requests_) {
195     const QueuedRequest& queued_request = it.second;
196     if (queued_request.type != Frame::kMsgBindServiceFieldNumber) {
197       continue;
198     }
199     base::WeakPtr<ServiceProxy> service_proxy = queued_request.service_proxy;
200     task_runner_->PostTask([service_proxy] {
201       if (service_proxy)
202         service_proxy->OnConnect(false);
203     });
204   }
205   service_bindings_.clear();
206   queued_bindings_.clear();
207 }
208 
OnDataAvailable(base::UnixSocket *)209 void ClientImpl::OnDataAvailable(base::UnixSocket*) {
210   size_t rsize;
211   do {
212     auto buf = frame_deserializer_.BeginReceive();
213     base::ScopedFile fd;
214     rsize = sock_->Receive(buf.data, buf.size, &fd);
215 #if PERFETTO_BUILDFLAG(PERFETTO_OS_WIN)
216     PERFETTO_DCHECK(!fd);
217 #else
218     if (fd) {
219       PERFETTO_DCHECK(!received_fd_);
220       int res = fcntl(*fd, F_SETFD, FD_CLOEXEC);
221       PERFETTO_DCHECK(res == 0);
222       received_fd_ = std::move(fd);
223     }
224 #endif
225     if (!frame_deserializer_.EndReceive(rsize)) {
226       // The endpoint tried to send a frame that is way too large.
227       return sock_->Shutdown(true);  // In turn will trigger an OnDisconnect().
228       // TODO(fmayer): check this.
229     }
230   } while (rsize > 0);
231 
232   while (std::unique_ptr<Frame> frame = frame_deserializer_.PopNextFrame())
233     OnFrameReceived(*frame);
234 }
235 
OnFrameReceived(const Frame & frame)236 void ClientImpl::OnFrameReceived(const Frame& frame) {
237   auto queued_requests_it = queued_requests_.find(frame.request_id());
238   if (queued_requests_it == queued_requests_.end()) {
239     PERFETTO_DLOG("OnFrameReceived(): got invalid request_id=%" PRIu64,
240                   static_cast<uint64_t>(frame.request_id()));
241     return;
242   }
243   QueuedRequest req = std::move(queued_requests_it->second);
244   queued_requests_.erase(queued_requests_it);
245 
246   if (req.type == Frame::kMsgBindServiceFieldNumber &&
247       frame.has_msg_bind_service_reply()) {
248     return OnBindServiceReply(std::move(req), frame.msg_bind_service_reply());
249   }
250   if (req.type == Frame::kMsgInvokeMethodFieldNumber &&
251       frame.has_msg_invoke_method_reply()) {
252     return OnInvokeMethodReply(std::move(req), frame.msg_invoke_method_reply());
253   }
254   if (frame.has_msg_request_error()) {
255     PERFETTO_DLOG("Host error: %s", frame.msg_request_error().error().c_str());
256     return;
257   }
258 
259   PERFETTO_DLOG(
260       "OnFrameReceived() request type=%d, received unknown frame in reply to "
261       "request_id=%" PRIu64,
262       req.type, static_cast<uint64_t>(frame.request_id()));
263 }
264 
OnBindServiceReply(QueuedRequest req,const Frame::BindServiceReply & reply)265 void ClientImpl::OnBindServiceReply(QueuedRequest req,
266                                     const Frame::BindServiceReply& reply) {
267   base::WeakPtr<ServiceProxy>& service_proxy = req.service_proxy;
268   if (!service_proxy)
269     return;
270   const char* svc_name = service_proxy->GetDescriptor().service_name;
271   if (!reply.success()) {
272     PERFETTO_DLOG("BindService(): unknown service_name=\"%s\"", svc_name);
273     return service_proxy->OnConnect(false /* success */);
274   }
275 
276   auto prev_service = service_bindings_.find(reply.service_id());
277   if (prev_service != service_bindings_.end() && prev_service->second.get()) {
278     PERFETTO_DLOG(
279         "BindService(): Trying to bind service \"%s\" but another service "
280         "named \"%s\" is already bound with the same ID.",
281         svc_name, prev_service->second->GetDescriptor().service_name);
282     return service_proxy->OnConnect(false /* success */);
283   }
284 
285   // Build the method [name] -> [remote_id] map.
286   std::map<std::string, MethodID> methods;
287   for (const auto& method : reply.methods()) {
288     if (method.name().empty() || method.id() <= 0) {
289       PERFETTO_DLOG("OnBindServiceReply(): invalid method \"%s\" -> %" PRIu64,
290                     method.name().c_str(), static_cast<uint64_t>(method.id()));
291       continue;
292     }
293     methods[method.name()] = method.id();
294   }
295   service_proxy->InitializeBinding(weak_ptr_factory_.GetWeakPtr(),
296                                    reply.service_id(), std::move(methods));
297   service_bindings_[reply.service_id()] = service_proxy;
298   service_proxy->OnConnect(true /* success */);
299 }
300 
OnInvokeMethodReply(QueuedRequest req,const Frame::InvokeMethodReply & reply)301 void ClientImpl::OnInvokeMethodReply(QueuedRequest req,
302                                      const Frame::InvokeMethodReply& reply) {
303   base::WeakPtr<ServiceProxy> service_proxy = req.service_proxy;
304   if (!service_proxy)
305     return;
306   std::unique_ptr<ProtoMessage> decoded_reply;
307   if (reply.success()) {
308     // If this becomes a hotspot, optimize by maintaining a dedicated hashtable.
309     for (const auto& method : service_proxy->GetDescriptor().methods) {
310       if (req.method_name == method.name) {
311         decoded_reply = method.reply_proto_decoder(reply.reply_proto());
312         break;
313       }
314     }
315   }
316   const RequestID request_id = req.request_id;
317   invoking_method_reply_ = true;
318   service_proxy->EndInvoke(request_id, std::move(decoded_reply),
319                            reply.has_more());
320   invoking_method_reply_ = false;
321 
322   // If this is a streaming method and future replies will be resolved, put back
323   // the |req| with the callback into the set of active requests.
324   if (reply.has_more())
325     queued_requests_.emplace(request_id, std::move(req));
326 }
327 
328 ClientImpl::QueuedRequest::QueuedRequest() = default;
329 
TakeReceivedFD()330 base::ScopedFile ClientImpl::TakeReceivedFD() {
331   return std::move(received_fd_);
332 }
333 
334 }  // namespace ipc
335 }  // namespace perfetto
336