1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/cloud_trace_processor/orchestrator_impl.h"
18 
19 #include <memory>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "perfetto/base/status.h"
25 #include "perfetto/ext/base/flat_hash_map.h"
26 #include "perfetto/ext/base/status_or.h"
27 #include "perfetto/ext/base/threading/future.h"
28 #include "perfetto/ext/base/threading/stream.h"
29 #include "perfetto/ext/cloud_trace_processor/worker.h"
30 #include "protos/perfetto/cloud_trace_processor/common.pb.h"
31 #include "protos/perfetto/cloud_trace_processor/orchestrator.pb.h"
32 #include "protos/perfetto/cloud_trace_processor/worker.pb.h"
33 #include "src/trace_processor/util/status_macros.h"
34 
35 namespace perfetto {
36 namespace cloud_trace_processor {
37 namespace {
38 
CreateResponseToStatus(base::StatusOr<protos::TracePoolShardCreateResponse> response_or)39 base::Future<base::Status> CreateResponseToStatus(
40     base::StatusOr<protos::TracePoolShardCreateResponse> response_or) {
41   return response_or.status();
42 }
43 
SetTracesResponseToStatus(base::StatusOr<protos::TracePoolShardSetTracesResponse> response_or)44 base::Future<base::Status> SetTracesResponseToStatus(
45     base::StatusOr<protos::TracePoolShardSetTracesResponse> response_or) {
46   return response_or.status();
47 }
48 
49 base::Future<base::StatusOr<protos::TracePoolQueryResponse>>
RpcResponseToPoolResponse(base::StatusOr<protos::TracePoolShardQueryResponse> resp)50 RpcResponseToPoolResponse(
51     base::StatusOr<protos::TracePoolShardQueryResponse> resp) {
52   RETURN_IF_ERROR(resp.status());
53   protos::TracePoolQueryResponse ret;
54   ret.set_trace(std::move(resp->trace()));
55   *ret.mutable_result() = std::move(*resp->mutable_result());
56   return ret;
57 }
58 
59 base::StatusOrStream<protos::TracePoolShardSetTracesResponse>
RoundRobinSetTraces(const std::vector<std::unique_ptr<Worker>> & workers,const std::vector<std::string> & traces)60 RoundRobinSetTraces(const std::vector<std::unique_ptr<Worker>>& workers,
61                     const std::vector<std::string>& traces) {
62   uint32_t worker_idx = 0;
63   std::vector<protos::TracePoolShardSetTracesArgs> protos;
64   protos.resize(workers.size());
65   for (const auto& trace : traces) {
66     protos[worker_idx].add_traces(trace);
67     worker_idx = (worker_idx + 1) % workers.size();
68   }
69 
70   using ShardResponse = protos::TracePoolShardSetTracesResponse;
71   std::vector<base::StatusOrStream<ShardResponse>> streams;
72   for (uint32_t i = 0; i < protos.size(); ++i) {
73     streams.emplace_back(workers[i]->TracePoolShardSetTraces(protos[i]));
74   }
75   return base::FlattenStreams(std::move(streams));
76 }
77 }  // namespace
78 
// Defaulted destructor. Orchestrator instances (e.g. the OrchestratorImpl
// built by CreateInProcess) are owned through std::unique_ptr<Orchestrator>,
// so the header is expected to declare this destructor virtual.
Orchestrator::~Orchestrator() = default;
80 
CreateInProcess(std::vector<std::unique_ptr<Worker>> workers)81 std::unique_ptr<Orchestrator> Orchestrator::CreateInProcess(
82     std::vector<std::unique_ptr<Worker>> workers) {
83   return std::unique_ptr<Orchestrator>(
84       new OrchestratorImpl(std::move(workers)));
85 }
86 
OrchestratorImpl(std::vector<std::unique_ptr<Worker>> workers)87 OrchestratorImpl::OrchestratorImpl(std::vector<std::unique_ptr<Worker>> workers)
88     : workers_(std::move(workers)) {}
89 
90 base::StatusOrFuture<protos::TracePoolCreateResponse>
TracePoolCreate(const protos::TracePoolCreateArgs & args)91 OrchestratorImpl::TracePoolCreate(const protos::TracePoolCreateArgs& args) {
92   if (args.pool_type() != protos::TracePoolType::SHARED) {
93     return base::StatusOr<protos::TracePoolCreateResponse>(
94         base::ErrStatus("Currently only SHARED pools are supported"));
95   }
96   if (!args.has_shared_pool_name()) {
97     return base::StatusOr<protos::TracePoolCreateResponse>(
98         base::ErrStatus("Pool name must be provided for SHARED pools"));
99   }
100 
101   std::string id = "shared:" + args.shared_pool_name();
102   TracePool* exist = pools_.Find(id);
103   if (exist) {
104     return base::StatusOr<protos::TracePoolCreateResponse>(
105         base::ErrStatus("Pool %s already exists", id.c_str()));
106   }
107   protos::TracePoolShardCreateArgs group_args;
108   group_args.set_pool_id(id);
109   group_args.set_pool_type(args.pool_type());
110 
111   using ShardResponse = protos::TracePoolShardCreateResponse;
112   std::vector<base::StatusOrStream<ShardResponse>> shards;
113   for (uint32_t i = 0; i < workers_.size(); ++i) {
114     shards.emplace_back(
115         base::StreamFromFuture(workers_[i]->TracePoolShardCreate(group_args)));
116   }
117   return base::FlattenStreams(std::move(shards))
118       .MapFuture(&CreateResponseToStatus)
119       .Collect(base::AllOkCollector())
120       .ContinueWith(
121           [this, id](base::StatusOr<ShardResponse> resp)
122               -> base::StatusOrFuture<protos::TracePoolCreateResponse> {
123             RETURN_IF_ERROR(resp.status());
124             auto it_and_inserted = pools_.Insert(id, TracePool());
125             if (!it_and_inserted.second) {
126               return base::ErrStatus("Unable to insert pool %s", id.c_str());
127             }
128             return protos::TracePoolCreateResponse();
129           });
130 }
131 
132 base::StatusOrFuture<protos::TracePoolSetTracesResponse>
TracePoolSetTraces(const protos::TracePoolSetTracesArgs & args)133 OrchestratorImpl::TracePoolSetTraces(
134     const protos::TracePoolSetTracesArgs& args) {
135   std::string id = args.pool_id();
136   TracePool* pool = pools_.Find(id);
137   if (!pool) {
138     return base::StatusOr<protos::TracePoolSetTracesResponse>(
139         base::ErrStatus("Unable to find pool %s", id.c_str()));
140   }
141   if (!pool->loaded_traces.empty()) {
142     return base::StatusOr<protos::TracePoolSetTracesResponse>(base::ErrStatus(
143         "Incrementally adding/removing items to pool not currently supported"));
144   }
145   pool->loaded_traces.assign(args.traces().begin(), args.traces().end());
146   return RoundRobinSetTraces(workers_, pool->loaded_traces)
147       .MapFuture(&SetTracesResponseToStatus)
148       .Collect(base::AllOkCollector())
149       .ContinueWith(
150           [](base::Status status)
151               -> base::StatusOrFuture<protos::TracePoolSetTracesResponse> {
152             RETURN_IF_ERROR(status);
153             return protos::TracePoolSetTracesResponse();
154           });
155 }
156 
157 base::StatusOrStream<protos::TracePoolQueryResponse>
TracePoolQuery(const protos::TracePoolQueryArgs & args)158 OrchestratorImpl::TracePoolQuery(const protos::TracePoolQueryArgs& args) {
159   TracePool* pool = pools_.Find(args.pool_id());
160   if (!pool) {
161     return base::StreamOf(base::StatusOr<protos::TracePoolQueryResponse>(
162         base::ErrStatus("Unable to find pool %s", args.pool_id().c_str())));
163   }
164   protos::TracePoolShardQueryArgs shard_args;
165   *shard_args.mutable_pool_id() = args.pool_id();
166   *shard_args.mutable_sql_query() = args.sql_query();
167 
168   using ShardResponse = protos::TracePoolShardQueryResponse;
169   std::vector<base::StatusOrStream<ShardResponse>> streams;
170   for (uint32_t i = 0; i < workers_.size(); ++i) {
171     streams.emplace_back(workers_[i]->TracePoolShardQuery(shard_args));
172   }
173   return base::FlattenStreams(std::move(streams))
174       .MapFuture(&RpcResponseToPoolResponse);
175 }
176 
177 base::StatusOrFuture<protos::TracePoolDestroyResponse>
TracePoolDestroy(const protos::TracePoolDestroyArgs & args)178 OrchestratorImpl::TracePoolDestroy(const protos::TracePoolDestroyArgs& args) {
179   std::string id = args.pool_id();
180   TracePool* pool = pools_.Find(id);
181   if (!pool) {
182     return base::StatusOr<protos::TracePoolDestroyResponse>(
183         base::ErrStatus("Unable to find pool %s", id.c_str()));
184   }
185   protos::TracePoolShardDestroyArgs shard_args;
186   *shard_args.mutable_pool_id() = id;
187 
188   using ShardResponse = protos::TracePoolShardDestroyResponse;
189   std::vector<base::StatusOrStream<ShardResponse>> streams;
190   for (uint32_t i = 0; i < workers_.size(); ++i) {
191     streams.emplace_back(
192         base::StreamFromFuture(workers_[i]->TracePoolShardDestroy(shard_args)));
193   }
194   return base::FlattenStreams(std::move(streams))
195       .MapFuture(
196           [](base::StatusOr<ShardResponse> resp) -> base::Future<base::Status> {
197             return resp.status();
198           })
199       .Collect(base::AllOkCollector())
200       .ContinueWith(
201           [this, id](base::Status status)
202               -> base::StatusOrFuture<protos::TracePoolDestroyResponse> {
203             RETURN_IF_ERROR(status);
204             PERFETTO_CHECK(pools_.Erase(id));
205             return protos::TracePoolDestroyResponse();
206           });
207 }
208 
209 }  // namespace cloud_trace_processor
210 }  // namespace perfetto
211