1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/trace_processor/perfetto_sql/engine/created_function.h"
18
19 #include <cstddef>
20 #include <queue>
21 #include <stack>
22
23 #include "perfetto/base/status.h"
24 #include "src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h"
25 #include "src/trace_processor/perfetto_sql/parser/function_util.h"
26 #include "src/trace_processor/sqlite/scoped_db.h"
27 #include "src/trace_processor/sqlite/sql_source.h"
28 #include "src/trace_processor/sqlite/sqlite_engine.h"
29 #include "src/trace_processor/sqlite/sqlite_utils.h"
30 #include "src/trace_processor/tp_metatrace.h"
31 #include "src/trace_processor/util/status_macros.h"
32
33 namespace perfetto {
34 namespace trace_processor {
35
36 namespace {
37
CheckNoMoreRows(sqlite3_stmt * stmt,sqlite3 * db,const FunctionPrototype & prototype)38 base::Status CheckNoMoreRows(sqlite3_stmt* stmt,
39 sqlite3* db,
40 const FunctionPrototype& prototype) {
41 int ret = sqlite3_step(stmt);
42 RETURN_IF_ERROR(SqliteRetToStatus(db, prototype.function_name, ret));
43 if (ret == SQLITE_ROW) {
44 auto expanded_sql = ScopedSqliteString(sqlite3_expanded_sql(stmt));
45 return base::ErrStatus(
46 "%s: multiple values were returned when executing function body. "
47 "Executed SQL was %s",
48 prototype.function_name.c_str(), expanded_sql.get());
49 }
50 PERFETTO_DCHECK(ret == SQLITE_DONE);
51 return base::OkStatus();
52 }
53
54 // Note: if the returned type is string / bytes, it will be invalidated by the
55 // next call to SQLite, so the caller must take care to either copy or use the
56 // value before calling SQLite again.
EvaluateScalarStatement(sqlite3_stmt * stmt,sqlite3 * db,const FunctionPrototype & prototype)57 base::StatusOr<SqlValue> EvaluateScalarStatement(
58 sqlite3_stmt* stmt,
59 sqlite3* db,
60 const FunctionPrototype& prototype) {
61 int ret = sqlite3_step(stmt);
62 RETURN_IF_ERROR(SqliteRetToStatus(db, prototype.function_name, ret));
63 if (ret == SQLITE_DONE) {
64 // No return value means we just return don't set |out|.
65 return SqlValue();
66 }
67
68 PERFETTO_DCHECK(ret == SQLITE_ROW);
69 size_t col_count = static_cast<size_t>(sqlite3_column_count(stmt));
70 if (col_count != 1) {
71 return base::ErrStatus(
72 "%s: SQL definition should only return one column: returned %zu "
73 "columns",
74 prototype.function_name.c_str(), col_count);
75 }
76
77 SqlValue result =
78 sqlite::utils::SqliteValueToSqlValue(sqlite3_column_value(stmt, 0));
79
80 // If we return a bytes type but have a null pointer, SQLite will convert this
81 // to an SQL null. However, for proto build functions, we actively want to
82 // distinguish between nulls and 0 byte strings. Therefore, change the value
83 // to an empty string.
84 if (result.type == SqlValue::kBytes && result.bytes_value == nullptr) {
85 PERFETTO_DCHECK(result.bytes_count == 0);
86 result.bytes_value = "";
87 }
88
89 return result;
90 }
91
BindArguments(sqlite3_stmt * stmt,const FunctionPrototype & prototype,size_t argc,sqlite3_value ** argv)92 base::Status BindArguments(sqlite3_stmt* stmt,
93 const FunctionPrototype& prototype,
94 size_t argc,
95 sqlite3_value** argv) {
96 // Bind all the arguments to the appropriate places in the function.
97 for (size_t i = 0; i < argc; ++i) {
98 RETURN_IF_ERROR(MaybeBindArgument(stmt, prototype.function_name,
99 prototype.arguments[i], argv[i]));
100 }
101 return base::OkStatus();
102 }
103
104 struct StoredSqlValue {
105 // unique_ptr to ensure that the pointers to these values are long-lived.
106 using OwnedString = std::unique_ptr<std::string>;
107 using OwnedBytes = std::unique_ptr<std::vector<uint8_t>>;
108 // variant is a pain to use, but it's the simplest way to ensure that
109 // the destructors run correctly for non-trivial members of the
110 // union.
111 using Data =
112 std::variant<int64_t, double, OwnedString, OwnedBytes, std::nullptr_t>;
113
StoredSqlValueperfetto::trace_processor::__anon9538c1d00111::StoredSqlValue114 StoredSqlValue(SqlValue value) {
115 switch (value.type) {
116 case SqlValue::Type::kNull:
117 data = nullptr;
118 break;
119 case SqlValue::Type::kLong:
120 data = value.long_value;
121 break;
122 case SqlValue::Type::kDouble:
123 data = value.double_value;
124 break;
125 case SqlValue::Type::kString:
126 data = std::make_unique<std::string>(value.string_value);
127 break;
128 case SqlValue::Type::kBytes:
129 const uint8_t* ptr = static_cast<const uint8_t*>(value.bytes_value);
130 data = std::make_unique<std::vector<uint8_t>>(ptr,
131 ptr + value.bytes_count);
132 break;
133 }
134 }
135
AsSqlValueperfetto::trace_processor::__anon9538c1d00111::StoredSqlValue136 SqlValue AsSqlValue() {
137 if (std::holds_alternative<std::nullptr_t>(data)) {
138 return SqlValue();
139 } else if (std::holds_alternative<int64_t>(data)) {
140 return SqlValue::Long(std::get<int64_t>(data));
141 } else if (std::holds_alternative<double>(data)) {
142 return SqlValue::Double(std::get<double>(data));
143 } else if (std::holds_alternative<OwnedString>(data)) {
144 const auto& str_ptr = std::get<OwnedString>(data);
145 return SqlValue::String(str_ptr->c_str());
146 } else if (std::holds_alternative<OwnedBytes>(data)) {
147 const auto& bytes_ptr = std::get<OwnedBytes>(data);
148 return SqlValue::Bytes(bytes_ptr->data(), bytes_ptr->size());
149 }
150 // GCC doesn't realize that the switch is exhaustive.
151 PERFETTO_CHECK(false);
152 return SqlValue();
153 }
154
155 Data data = nullptr;
156 };
157
158 class Memoizer {
159 public:
160 // Supported arguments. For now, only functions with a single int argument are
161 // supported.
162 using MemoizedArgs = int64_t;
163
164 // Enables memoization.
165 // Only functions with a single int argument returning ints are supported.
EnableMemoization(const FunctionPrototype & prototype)166 base::Status EnableMemoization(const FunctionPrototype& prototype) {
167 if (prototype.arguments.size() != 1 ||
168 TypeToSqlValueType(prototype.arguments[0].type()) !=
169 SqlValue::Type::kLong) {
170 return base::ErrStatus(
171 "EXPERIMENTAL_MEMOIZE: Function %s should take one int argument",
172 prototype.function_name.c_str());
173 }
174 enabled_ = true;
175 return base::OkStatus();
176 }
177
178 // Returns the memoized value for the current invocation if it exists.
GetMemoizedValue(MemoizedArgs args)179 std::optional<SqlValue> GetMemoizedValue(MemoizedArgs args) {
180 if (!enabled_) {
181 return std::nullopt;
182 }
183 StoredSqlValue* value = memoized_values_.Find(args);
184 if (!value) {
185 return std::nullopt;
186 }
187 return value->AsSqlValue();
188 }
189
HasMemoizedValue(MemoizedArgs args)190 bool HasMemoizedValue(MemoizedArgs args) {
191 return GetMemoizedValue(args).has_value();
192 }
193
194 // Saves the return value of the current invocation for memoization.
Memoize(MemoizedArgs args,SqlValue value)195 void Memoize(MemoizedArgs args, SqlValue value) {
196 if (!enabled_) {
197 return;
198 }
199 memoized_values_.Insert(args, StoredSqlValue(value));
200 }
201
202 // Checks that the function has a single int argument and returns it.
AsMemoizedArgs(size_t argc,sqlite3_value ** argv)203 static std::optional<MemoizedArgs> AsMemoizedArgs(size_t argc,
204 sqlite3_value** argv) {
205 if (argc != 1) {
206 return std::nullopt;
207 }
208 SqlValue arg = sqlite::utils::SqliteValueToSqlValue(argv[0]);
209 if (arg.type != SqlValue::Type::kLong) {
210 return std::nullopt;
211 }
212 return arg.AsLong();
213 }
214
enabled() const215 bool enabled() const { return enabled_; }
216
217 private:
218 bool enabled_ = false;
219 base::FlatHashMap<MemoizedArgs, StoredSqlValue> memoized_values_;
220 };
221
222 // A helper to unroll recursive calls: to minimise the amount of stack space
223 // used, memoized recursive calls are evaluated using an on-heap queue.
224 //
225 // We compute the function in two passes:
226 // - In the first pass, we evaluate the statement to discover which recursive
227 // calls it makes, returning null from recursive calls and ignoring the
228 // result.
229 // - In the second pass, we evaluate the statement again, but this time we
230 // memoize the result of each recursive call.
231 //
232 // We maintain a queue for scheduled "first pass" calls and a stack for the
233 // scheduled "second pass" calls, evaluating available first pass calls, then
234 // second pass calls. When we evaluate a first pass call, the further calls to
235 // CreatedFunction::Run will just add it to the "first pass" queue. The second
236 // pass, however, will evaluate the function normally, typically just using the
237 // memoized result for the dependent calls. However, if the recursive calls
238 // depend on the return value of the function, we will proceed with normal
239 // recursion.
240 //
241 // To make it more concrete, consider an following example.
242 // We have a function computing factorial (f) and we want to compute f(3).
243 //
244 // SELECT create_function('f(x INT)', 'INT',
245 // 'SELECT IIF($x = 0, 1, $x * f($x - 1))');
246 // SELECT experimental_memoize('f');
247 // SELECT f(3);
248 //
249 // - We start with a call to f(3). It executes the statement as normal, which
250 // recursively calls f(2).
251 // - When f(2) is called, we detect that it is a recursive call and we start
252 // unrolling it, entering RecursiveCallUnroller::Run.
253 // - We schedule first pass for 2 and the state of the unroller
254 // is first_pass: [2], second_pass: [].
255 // - Then we compute the first pass for f(2). It calls f(1), which is ignored
256 // due to OnFunctionCall returning kIgnoreDueToFirstPass and 1 is added to the
257 // first pass queue. 2 is taked out of the first pass queue and moved to the
258 // second pass stack. State: first_pass: [1], second_pass: [2].
259 // - Then we compute the first pass for 1. The similar thing happens: f(0) is
260 // called and ignored, 0 is added to first_pass, 1 is added to second_pass.
261 // State: first_pass: [0], second_pass: [2, 1].
262 // - Then we compute the first pass for 0. It doesn't make further calls, so
263 // 0 is moved to the second pass stack.
264 // State: first_pass: [], second_pass: [2, 1, 0].
265 // - Then we compute the second pass for 0. It just returns 1.
266 // State: first_pass: [], second_pass: [2, 1], results: {0: 1}.
267 // - Then we compute the second pass for 1. It calls f(0), which is memoized.
268 // State: first_pass: [], second_pass: [2], results: {0: 1, 1: 1}.
269 // - Then we compute the second pass for 1. It calls f(1), which is memoized.
270 // State: first_pass: [], second_pass: [], results: {0: 1, 1: 1, 2: 2}.
271 // - As both first_pass and second_pass are empty, we return from
272 // RecursiveCallUnroller::Run.
273 // - Control is returned to CreatedFunction::Run for f(2), which returns
274 // memoized value.
275 // - Then control is returned to CreatedFunction::Run for f(3), which completes
276 // the computation.
277 class RecursiveCallUnroller {
278 public:
RecursiveCallUnroller(PerfettoSqlEngine * engine,sqlite3_stmt * stmt,const FunctionPrototype & prototype,Memoizer & memoizer)279 RecursiveCallUnroller(PerfettoSqlEngine* engine,
280 sqlite3_stmt* stmt,
281 const FunctionPrototype& prototype,
282 Memoizer& memoizer)
283 : engine_(engine),
284 stmt_(stmt),
285 prototype_(prototype),
286 memoizer_(memoizer) {}
287
288 // Whether we should just return null due to us being in the "first pass".
289 enum class FunctionCallState {
290 kIgnoreDueToFirstPass,
291 kEvaluate,
292 };
293
OnFunctionCall(Memoizer::MemoizedArgs args)294 base::StatusOr<FunctionCallState> OnFunctionCall(
295 Memoizer::MemoizedArgs args) {
296 // If we are in the second pass, we just continue the function execution,
297 // including checking if a memoized value is available and returning it.
298 //
299 // We generally expect a memoized value to be available, but there are
300 // cases when it might not be the case, e.g. when which recursive calls are
301 // made depends on the return value of the function, e.g. for the following
302 // function, the first pass will not detect f(y) calls, so they will
303 // be computed recursively.
304 // f(x): SELECT max(f(y)) FROM y WHERE y < f($x - 1);
305 if (state_ == State::kComputingSecondPass) {
306 return FunctionCallState::kEvaluate;
307 }
308 if (!memoizer_.HasMemoizedValue(args)) {
309 ArgState* state = visited_.Find(args);
310 if (state) {
311 // Detect recursive loops, e.g. f(1) calling f(2) calling f(1).
312 if (*state == ArgState::kEvaluating) {
313 return base::ErrStatus("Infinite recursion detected");
314 }
315 } else {
316 visited_.Insert(args, ArgState::kScheduled);
317 first_pass_.push(args);
318 }
319 }
320 return FunctionCallState::kIgnoreDueToFirstPass;
321 }
322
Run(Memoizer::MemoizedArgs initial_args)323 base::Status Run(Memoizer::MemoizedArgs initial_args) {
324 PERFETTO_TP_TRACE(metatrace::Category::FUNCTION_CALL,
325 "UNROLL_RECURSIVE_FUNCTION_CALL",
326 [&](metatrace::Record* r) {
327 r->AddArg("Function", prototype_.function_name);
328 r->AddArg("Arg 0", std::to_string(initial_args));
329 });
330
331 first_pass_.push(initial_args);
332 visited_.Insert(initial_args, ArgState::kScheduled);
333
334 while (!first_pass_.empty() || !second_pass_.empty()) {
335 // If we have scheduled first pass calls, we evaluate them first.
336 if (!first_pass_.empty()) {
337 state_ = State::kComputingFirstPass;
338 Memoizer::MemoizedArgs args = first_pass_.front();
339
340 PERFETTO_TP_TRACE(metatrace::Category::FUNCTION_CALL,
341 "SQL_FUNCTION_CALL", [&](metatrace::Record* r) {
342 r->AddArg("Function", prototype_.function_name);
343 r->AddArg("Type", "UnrollRecursiveCall_FirstPass");
344 r->AddArg("Arg 0", std::to_string(args));
345 });
346
347 first_pass_.pop();
348 second_pass_.push(args);
349 Evaluate(args).status();
350 continue;
351 }
352
353 state_ = State::kComputingSecondPass;
354 Memoizer::MemoizedArgs args = second_pass_.top();
355
356 PERFETTO_TP_TRACE(metatrace::Category::FUNCTION_CALL, "SQL_FUNCTION_CALL",
357 [&](metatrace::Record* r) {
358 r->AddArg("Function", prototype_.function_name);
359 r->AddArg("Type", "UnrollRecursiveCall_SecondPass");
360 r->AddArg("Arg 0", std::to_string(args));
361 });
362
363 visited_.Insert(args, ArgState::kEvaluating);
364 second_pass_.pop();
365 base::StatusOr<std::optional<int64_t>> result = Evaluate(args);
366 RETURN_IF_ERROR(result.status());
367 std::optional<int64_t> maybe_int_result = result.value();
368 if (!maybe_int_result.has_value()) {
369 continue;
370 }
371 visited_.Insert(args, ArgState::kEvaluated);
372 memoizer_.Memoize(args, SqlValue::Long(*maybe_int_result));
373 }
374 return base::OkStatus();
375 }
376
377 private:
378 // This function returns:
379 // - base::ErrStatus if the evaluation of the function failed.
380 // - std::nullopt if the function returned a non-integer value.
381 // - the result of the function otherwise.
Evaluate(Memoizer::MemoizedArgs args)382 base::StatusOr<std::optional<int64_t>> Evaluate(Memoizer::MemoizedArgs args) {
383 RETURN_IF_ERROR(MaybeBindIntArgument(stmt_, prototype_.function_name,
384 prototype_.arguments[0], args));
385 base::StatusOr<SqlValue> result = EvaluateScalarStatement(
386 stmt_, engine_->sqlite_engine()->db(), prototype_);
387 sqlite3_reset(stmt_);
388 sqlite3_clear_bindings(stmt_);
389 RETURN_IF_ERROR(result.status());
390 if (result->type != SqlValue::Type::kLong) {
391 return std::optional<int64_t>(std::nullopt);
392 }
393 return std::optional<int64_t>(result->long_value);
394 }
395
396 PerfettoSqlEngine* engine_;
397 sqlite3_stmt* stmt_;
398 const FunctionPrototype& prototype_;
399 Memoizer& memoizer_;
400
401 // Current state of the evaluation.
402 enum class State {
403 kComputingFirstPass,
404 kComputingSecondPass,
405 };
406 State state_ = State::kComputingFirstPass;
407
408 // A state of evaluation of a given argument.
409 enum class ArgState {
410 kScheduled,
411 kEvaluating,
412 kEvaluated,
413 };
414
415 // See the class-level comment for the explanation of the two passes.
416 std::queue<Memoizer::MemoizedArgs> first_pass_;
417 base::FlatHashMap<Memoizer::MemoizedArgs, ArgState> visited_;
418 std::stack<Memoizer::MemoizedArgs> second_pass_;
419 };
420
421 } // namespace
422
423 // This class is used to store the state of a CREATE_FUNCTION call.
424 // It is used to store the state of the function across multiple invocations
425 // of the function (e.g. when the function is called recursively).
426 class State : public CreatedFunction::Context {
427 public:
State(PerfettoSqlEngine * engine)428 explicit State(PerfettoSqlEngine* engine) : engine_(engine) {}
429 ~State() override;
430
431 // Prepare a statement and push it into the stack of allocated statements
432 // for this function.
PrepareStatement()433 base::Status PrepareStatement() {
434 SqliteEngine::PreparedStatement stmt =
435 engine_->sqlite_engine()->PrepareStatement(*sql_);
436 RETURN_IF_ERROR(stmt.status());
437 is_valid_ = true;
438 stmts_.push_back(std::move(stmt));
439 return base::OkStatus();
440 }
441
442 // Sets the state of the function. Should be called only when the function
443 // is invalid (i.e. when it is first created or when the previous statement
444 // failed to prepare).
Reset(FunctionPrototype prototype,sql_argument::Type return_type,SqlSource sql)445 void Reset(FunctionPrototype prototype,
446 sql_argument::Type return_type,
447 SqlSource sql) {
448 // Re-registration of valid functions is not allowed.
449 PERFETTO_DCHECK(!is_valid_);
450 PERFETTO_DCHECK(stmts_.empty());
451
452 prototype_ = std::move(prototype);
453 return_type_ = return_type;
454 sql_ = std::move(sql);
455 }
456
457 // This function is called each time the function is called.
458 // It ensures that we have a statement for the current recursion level,
459 // allocating a new one if needed.
PushStackEntry()460 base::Status PushStackEntry() {
461 ++current_recursion_level_;
462 if (current_recursion_level_ > stmts_.size()) {
463 return PrepareStatement();
464 }
465 return base::OkStatus();
466 }
467
468 // Returns the statement that is used for the current invocation.
CurrentStatement()469 sqlite3_stmt* CurrentStatement() {
470 return stmts_[current_recursion_level_ - 1].sqlite_stmt();
471 }
472
473 // This function is called each time the function returns and resets the
474 // statement that this invocation used.
PopStackEntry()475 void PopStackEntry() {
476 if (current_recursion_level_ > stmts_.size()) {
477 // This is possible if we didn't prepare the statement and returned
478 // an error.
479 return;
480 }
481 sqlite3_reset(CurrentStatement());
482 sqlite3_clear_bindings(CurrentStatement());
483 --current_recursion_level_;
484 }
485
OnFunctionCall(Memoizer::MemoizedArgs args)486 base::StatusOr<RecursiveCallUnroller::FunctionCallState> OnFunctionCall(
487 Memoizer::MemoizedArgs args) {
488 if (!recursive_call_unroller_) {
489 return RecursiveCallUnroller::FunctionCallState::kEvaluate;
490 }
491 return recursive_call_unroller_->OnFunctionCall(args);
492 }
493
494 // Called before checking the function for memoization.
UnrollRecursiveCallIfNeeded(Memoizer::MemoizedArgs args)495 base::Status UnrollRecursiveCallIfNeeded(Memoizer::MemoizedArgs args) {
496 if (!memoizer_.enabled() || !is_in_recursive_call() ||
497 recursive_call_unroller_) {
498 return base::OkStatus();
499 }
500 // If we are in a recursive call, we need to check if we have already
501 // computed the result for the current arguments.
502 if (memoizer_.HasMemoizedValue(args)) {
503 return base::OkStatus();
504 }
505
506 // If we are in a beginning of a function call:
507 // - is a recursive,
508 // - can be memoized,
509 // - hasn't been memoized already, and
510 // - hasn't start unrolling yet;
511 // start the unrolling and run the unrolling loop.
512 recursive_call_unroller_ = std::make_unique<RecursiveCallUnroller>(
513 engine_, CurrentStatement(), prototype_, memoizer_);
514 auto status = recursive_call_unroller_->Run(args);
515 recursive_call_unroller_.reset();
516 return status;
517 }
518
519 // Schedule a statement to be validated that it is indeed doesn't have any
520 // more rows.
ScheduleEmptyStatementValidation(sqlite3_stmt * stmt)521 void ScheduleEmptyStatementValidation(sqlite3_stmt* stmt) {
522 empty_stmts_to_validate_.push_back(stmt);
523 }
524
ValidateEmptyStatements()525 base::Status ValidateEmptyStatements() {
526 while (!empty_stmts_to_validate_.empty()) {
527 sqlite3_stmt* stmt = empty_stmts_to_validate_.back();
528 empty_stmts_to_validate_.pop_back();
529 RETURN_IF_ERROR(
530 CheckNoMoreRows(stmt, engine_->sqlite_engine()->db(), prototype_));
531 }
532 return base::OkStatus();
533 }
534
is_in_recursive_call() const535 bool is_in_recursive_call() const { return current_recursion_level_ > 1; }
536
EnableMemoization()537 base::Status EnableMemoization() {
538 return memoizer_.EnableMemoization(prototype_);
539 }
540
engine() const541 PerfettoSqlEngine* engine() const { return engine_; }
542
prototype() const543 const FunctionPrototype& prototype() const { return prototype_; }
544
return_type() const545 sql_argument::Type return_type() const { return return_type_; }
546
sql() const547 const std::string& sql() const { return sql_->sql(); }
548
is_valid() const549 bool is_valid() const { return is_valid_; }
550
memoizer()551 Memoizer& memoizer() { return memoizer_; }
552
553 private:
554 PerfettoSqlEngine* engine_;
555 FunctionPrototype prototype_;
556 sql_argument::Type return_type_;
557 std::optional<SqlSource> sql_;
558 // Perfetto SQL functions support recursion. Given that each function call in
559 // the stack requires a dedicated statement, we maintain a stack of prepared
560 // statements and use the top one for each new call (allocating a new one if
561 // needed).
562 std::vector<SqliteEngine::PreparedStatement> stmts_;
563 // A list of statements to verify to ensure that they don't have more rows
564 // in VerifyPostConditions.
565 std::vector<sqlite3_stmt*> empty_stmts_to_validate_;
566 size_t current_recursion_level_ = 0;
567 // Function re-registration is not allowed, but the user is allowed to define
568 // the function again if the first call failed. |is_valid_| flag helps that
569 // by tracking whether the current function definition is valid (in which case
570 // re-registration is not allowed).
571 bool is_valid_ = false;
572 Memoizer memoizer_;
573 // Set if we are in a middle of unrolling a recursive call.
574 std::unique_ptr<RecursiveCallUnroller> recursive_call_unroller_;
575 };
576
577 State::~State() = default;
578
MakeContext(PerfettoSqlEngine * engine)579 std::unique_ptr<CreatedFunction::Context> CreatedFunction::MakeContext(
580 PerfettoSqlEngine* engine) {
581 return std::make_unique<State>(engine);
582 }
583
IsValid(Context * ctx)584 bool CreatedFunction::IsValid(Context* ctx) {
585 return static_cast<State*>(ctx)->is_valid();
586 }
587
Reset(Context * ctx,PerfettoSqlEngine * engine)588 void CreatedFunction::Reset(Context* ctx, PerfettoSqlEngine* engine) {
589 ctx->~Context();
590 new (ctx) State(engine);
591 }
592
Run(CreatedFunction::Context * ctx,size_t argc,sqlite3_value ** argv,SqlValue & out,Destructors &)593 base::Status CreatedFunction::Run(CreatedFunction::Context* ctx,
594 size_t argc,
595 sqlite3_value** argv,
596 SqlValue& out,
597 Destructors&) {
598 State* state = static_cast<State*>(ctx);
599
600 // Enter the function and ensure that we have a statement allocated.
601 RETURN_IF_ERROR(state->PushStackEntry());
602
603 if (argc != state->prototype().arguments.size()) {
604 return base::ErrStatus(
605 "%s: invalid number of args; expected %zu, received %zu",
606 state->prototype().function_name.c_str(),
607 state->prototype().arguments.size(), argc);
608 }
609
610 // Type check all the arguments.
611 for (size_t i = 0; i < argc; ++i) {
612 sqlite3_value* arg = argv[i];
613 sql_argument::Type type = state->prototype().arguments[i].type();
614 base::Status status = sqlite::utils::TypeCheckSqliteValue(
615 arg, sql_argument::TypeToSqlValueType(type),
616 sql_argument::TypeToHumanFriendlyString(type));
617 if (!status.ok()) {
618 return base::ErrStatus("%s[arg=%s]: argument %zu %s",
619 state->prototype().function_name.c_str(),
620 sqlite3_value_text(arg), i, status.c_message());
621 }
622 }
623
624 std::optional<Memoizer::MemoizedArgs> memoized_args =
625 Memoizer::AsMemoizedArgs(argc, argv);
626
627 if (memoized_args) {
628 // If we are in the middle of an recursive calls unrolling, we might want to
629 // ignore the function invocation. See the comment in RecursiveCallUnroller
630 // for more details.
631 base::StatusOr<RecursiveCallUnroller::FunctionCallState> unroll_state =
632 state->OnFunctionCall(*memoized_args);
633 RETURN_IF_ERROR(unroll_state.status());
634 if (*unroll_state ==
635 RecursiveCallUnroller::FunctionCallState::kIgnoreDueToFirstPass) {
636 // Return NULL.
637 return base::OkStatus();
638 }
639
640 RETURN_IF_ERROR(state->UnrollRecursiveCallIfNeeded(*memoized_args));
641
642 std::optional<SqlValue> memoized_value =
643 state->memoizer().GetMemoizedValue(*memoized_args);
644 if (memoized_value) {
645 out = *memoized_value;
646 return base::OkStatus();
647 }
648 }
649
650 PERFETTO_TP_TRACE(
651 metatrace::Category::FUNCTION_CALL, "SQL_FUNCTION_CALL",
652 [state, argv](metatrace::Record* r) {
653 r->AddArg("Function", state->prototype().function_name.c_str());
654 for (uint32_t i = 0; i < state->prototype().arguments.size(); ++i) {
655 std::string key = "Arg " + std::to_string(i);
656 const char* value =
657 reinterpret_cast<const char*>(sqlite3_value_text(argv[i]));
658 r->AddArg(base::StringView(key),
659 value ? base::StringView(value) : base::StringView("NULL"));
660 }
661 });
662
663 RETURN_IF_ERROR(
664 BindArguments(state->CurrentStatement(), state->prototype(), argc, argv));
665 auto result = EvaluateScalarStatement(state->CurrentStatement(),
666 state->engine()->sqlite_engine()->db(),
667 state->prototype());
668 RETURN_IF_ERROR(result.status());
669 out = result.value();
670 state->ScheduleEmptyStatementValidation(state->CurrentStatement());
671
672 if (memoized_args) {
673 state->memoizer().Memoize(*memoized_args, out);
674 }
675
676 return base::OkStatus();
677 }
678
Cleanup(CreatedFunction::Context * ctx)679 void CreatedFunction::Cleanup(CreatedFunction::Context* ctx) {
680 // Clear the statement.
681 static_cast<State*>(ctx)->PopStackEntry();
682 }
683
VerifyPostConditions(CreatedFunction::Context * ctx)684 base::Status CreatedFunction::VerifyPostConditions(
685 CreatedFunction::Context* ctx) {
686 return static_cast<State*>(ctx)->ValidateEmptyStatements();
687 }
688
Prepare(CreatedFunction::Context * ctx,FunctionPrototype prototype,sql_argument::Type return_type,SqlSource source)689 base::Status CreatedFunction::Prepare(CreatedFunction::Context* ctx,
690 FunctionPrototype prototype,
691 sql_argument::Type return_type,
692 SqlSource source) {
693 State* state = static_cast<State*>(ctx);
694 state->Reset(std::move(prototype), return_type, std::move(source));
695
696 // Ideally, we would unregister the function here if the statement prep
697 // failed, but SQLite doesn't allow unregistering functions inside active
698 // statements. So instead we'll just try to prepare the statement when calling
699 // this function, which will return an error.
700 return state->PrepareStatement();
701 }
702
EnableMemoization(Context * ctx)703 base::Status CreatedFunction::EnableMemoization(Context* ctx) {
704 return static_cast<State*>(ctx)->EnableMemoization();
705 }
706
707 } // namespace trace_processor
708 } // namespace perfetto
709