• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/trace_processor/perfetto_sql/engine/created_function.h"
18 
19 #include <cstddef>
20 #include <queue>
21 #include <stack>
22 
23 #include "perfetto/base/status.h"
24 #include "src/trace_processor/perfetto_sql/engine/perfetto_sql_engine.h"
25 #include "src/trace_processor/perfetto_sql/parser/function_util.h"
26 #include "src/trace_processor/sqlite/scoped_db.h"
27 #include "src/trace_processor/sqlite/sql_source.h"
28 #include "src/trace_processor/sqlite/sqlite_engine.h"
29 #include "src/trace_processor/sqlite/sqlite_utils.h"
30 #include "src/trace_processor/tp_metatrace.h"
31 #include "src/trace_processor/util/status_macros.h"
32 
33 namespace perfetto {
34 namespace trace_processor {
35 
36 namespace {
37 
CheckNoMoreRows(sqlite3_stmt * stmt,sqlite3 * db,const FunctionPrototype & prototype)38 base::Status CheckNoMoreRows(sqlite3_stmt* stmt,
39                              sqlite3* db,
40                              const FunctionPrototype& prototype) {
41   int ret = sqlite3_step(stmt);
42   RETURN_IF_ERROR(SqliteRetToStatus(db, prototype.function_name, ret));
43   if (ret == SQLITE_ROW) {
44     auto expanded_sql = ScopedSqliteString(sqlite3_expanded_sql(stmt));
45     return base::ErrStatus(
46         "%s: multiple values were returned when executing function body. "
47         "Executed SQL was %s",
48         prototype.function_name.c_str(), expanded_sql.get());
49   }
50   PERFETTO_DCHECK(ret == SQLITE_DONE);
51   return base::OkStatus();
52 }
53 
54 // Note: if the returned type is string / bytes, it will be invalidated by the
55 // next call to SQLite, so the caller must take care to either copy or use the
56 // value before calling SQLite again.
EvaluateScalarStatement(sqlite3_stmt * stmt,sqlite3 * db,const FunctionPrototype & prototype)57 base::StatusOr<SqlValue> EvaluateScalarStatement(
58     sqlite3_stmt* stmt,
59     sqlite3* db,
60     const FunctionPrototype& prototype) {
61   int ret = sqlite3_step(stmt);
62   RETURN_IF_ERROR(SqliteRetToStatus(db, prototype.function_name, ret));
63   if (ret == SQLITE_DONE) {
64     // No return value means we just return don't set |out|.
65     return SqlValue();
66   }
67 
68   PERFETTO_DCHECK(ret == SQLITE_ROW);
69   size_t col_count = static_cast<size_t>(sqlite3_column_count(stmt));
70   if (col_count != 1) {
71     return base::ErrStatus(
72         "%s: SQL definition should only return one column: returned %zu "
73         "columns",
74         prototype.function_name.c_str(), col_count);
75   }
76 
77   SqlValue result =
78       sqlite::utils::SqliteValueToSqlValue(sqlite3_column_value(stmt, 0));
79 
80   // If we return a bytes type but have a null pointer, SQLite will convert this
81   // to an SQL null. However, for proto build functions, we actively want to
82   // distinguish between nulls and 0 byte strings. Therefore, change the value
83   // to an empty string.
84   if (result.type == SqlValue::kBytes && result.bytes_value == nullptr) {
85     PERFETTO_DCHECK(result.bytes_count == 0);
86     result.bytes_value = "";
87   }
88 
89   return result;
90 }
91 
BindArguments(sqlite3_stmt * stmt,const FunctionPrototype & prototype,size_t argc,sqlite3_value ** argv)92 base::Status BindArguments(sqlite3_stmt* stmt,
93                            const FunctionPrototype& prototype,
94                            size_t argc,
95                            sqlite3_value** argv) {
96   // Bind all the arguments to the appropriate places in the function.
97   for (size_t i = 0; i < argc; ++i) {
98     RETURN_IF_ERROR(MaybeBindArgument(stmt, prototype.function_name,
99                                       prototype.arguments[i], argv[i]));
100   }
101   return base::OkStatus();
102 }
103 
104 struct StoredSqlValue {
105   // unique_ptr to ensure that the pointers to these values are long-lived.
106   using OwnedString = std::unique_ptr<std::string>;
107   using OwnedBytes = std::unique_ptr<std::vector<uint8_t>>;
108   // variant is a pain to use, but it's the simplest way to ensure that
109   // the destructors run correctly for non-trivial members of the
110   // union.
111   using Data =
112       std::variant<int64_t, double, OwnedString, OwnedBytes, std::nullptr_t>;
113 
StoredSqlValueperfetto::trace_processor::__anon9538c1d00111::StoredSqlValue114   StoredSqlValue(SqlValue value) {
115     switch (value.type) {
116       case SqlValue::Type::kNull:
117         data = nullptr;
118         break;
119       case SqlValue::Type::kLong:
120         data = value.long_value;
121         break;
122       case SqlValue::Type::kDouble:
123         data = value.double_value;
124         break;
125       case SqlValue::Type::kString:
126         data = std::make_unique<std::string>(value.string_value);
127         break;
128       case SqlValue::Type::kBytes:
129         const uint8_t* ptr = static_cast<const uint8_t*>(value.bytes_value);
130         data = std::make_unique<std::vector<uint8_t>>(ptr,
131                                                       ptr + value.bytes_count);
132         break;
133     }
134   }
135 
AsSqlValueperfetto::trace_processor::__anon9538c1d00111::StoredSqlValue136   SqlValue AsSqlValue() {
137     if (std::holds_alternative<std::nullptr_t>(data)) {
138       return SqlValue();
139     } else if (std::holds_alternative<int64_t>(data)) {
140       return SqlValue::Long(std::get<int64_t>(data));
141     } else if (std::holds_alternative<double>(data)) {
142       return SqlValue::Double(std::get<double>(data));
143     } else if (std::holds_alternative<OwnedString>(data)) {
144       const auto& str_ptr = std::get<OwnedString>(data);
145       return SqlValue::String(str_ptr->c_str());
146     } else if (std::holds_alternative<OwnedBytes>(data)) {
147       const auto& bytes_ptr = std::get<OwnedBytes>(data);
148       return SqlValue::Bytes(bytes_ptr->data(), bytes_ptr->size());
149     }
150     // GCC doesn't realize that the switch is exhaustive.
151     PERFETTO_CHECK(false);
152     return SqlValue();
153   }
154 
155   Data data = nullptr;
156 };
157 
158 class Memoizer {
159  public:
160   // Supported arguments. For now, only functions with a single int argument are
161   // supported.
162   using MemoizedArgs = int64_t;
163 
164   // Enables memoization.
165   // Only functions with a single int argument returning ints are supported.
EnableMemoization(const FunctionPrototype & prototype)166   base::Status EnableMemoization(const FunctionPrototype& prototype) {
167     if (prototype.arguments.size() != 1 ||
168         TypeToSqlValueType(prototype.arguments[0].type()) !=
169             SqlValue::Type::kLong) {
170       return base::ErrStatus(
171           "EXPERIMENTAL_MEMOIZE: Function %s should take one int argument",
172           prototype.function_name.c_str());
173     }
174     enabled_ = true;
175     return base::OkStatus();
176   }
177 
178   // Returns the memoized value for the current invocation if it exists.
GetMemoizedValue(MemoizedArgs args)179   std::optional<SqlValue> GetMemoizedValue(MemoizedArgs args) {
180     if (!enabled_) {
181       return std::nullopt;
182     }
183     StoredSqlValue* value = memoized_values_.Find(args);
184     if (!value) {
185       return std::nullopt;
186     }
187     return value->AsSqlValue();
188   }
189 
HasMemoizedValue(MemoizedArgs args)190   bool HasMemoizedValue(MemoizedArgs args) {
191     return GetMemoizedValue(args).has_value();
192   }
193 
194   // Saves the return value of the current invocation for memoization.
Memoize(MemoizedArgs args,SqlValue value)195   void Memoize(MemoizedArgs args, SqlValue value) {
196     if (!enabled_) {
197       return;
198     }
199     memoized_values_.Insert(args, StoredSqlValue(value));
200   }
201 
202   // Checks that the function has a single int argument and returns it.
AsMemoizedArgs(size_t argc,sqlite3_value ** argv)203   static std::optional<MemoizedArgs> AsMemoizedArgs(size_t argc,
204                                                     sqlite3_value** argv) {
205     if (argc != 1) {
206       return std::nullopt;
207     }
208     SqlValue arg = sqlite::utils::SqliteValueToSqlValue(argv[0]);
209     if (arg.type != SqlValue::Type::kLong) {
210       return std::nullopt;
211     }
212     return arg.AsLong();
213   }
214 
enabled() const215   bool enabled() const { return enabled_; }
216 
217  private:
218   bool enabled_ = false;
219   base::FlatHashMap<MemoizedArgs, StoredSqlValue> memoized_values_;
220 };
221 
222 // A helper to unroll recursive calls: to minimise the amount of stack space
223 // used, memoized recursive calls are evaluated using an on-heap queue.
224 //
225 // We compute the function in two passes:
226 // - In the first pass, we evaluate the statement to discover which recursive
227 //   calls it makes, returning null from recursive calls and ignoring the
228 //   result.
229 // - In the second pass, we evaluate the statement again, but this time we
230 //   memoize the result of each recursive call.
231 //
232 // We maintain a queue for scheduled "first pass" calls and a stack for the
233 // scheduled "second pass" calls, evaluating available first pass calls, then
234 // second pass calls. When we evaluate a first pass call, the further calls to
235 // CreatedFunction::Run will just add it to the "first pass" queue. The second
236 // pass, however, will evaluate the function normally, typically just using the
237 // memoized result for the dependent calls. However, if the recursive calls
238 // depend on the return value of the function, we will proceed with normal
239 // recursion.
240 //
241 // To make it more concrete, consider an following example.
242 // We have a function computing factorial (f) and we want to compute f(3).
243 //
244 // SELECT create_function('f(x INT)', 'INT',
245 // 'SELECT IIF($x = 0, 1, $x * f($x - 1))');
246 // SELECT experimental_memoize('f');
247 // SELECT f(3);
248 //
249 // - We start with a call to f(3). It executes the statement as normal, which
250 //   recursively calls f(2).
251 // - When f(2) is called, we detect that it is a recursive call and we start
252 //   unrolling it, entering RecursiveCallUnroller::Run.
253 // - We schedule first pass for 2 and the state of the unroller
254 //   is first_pass: [2], second_pass: [].
255 // - Then we compute the first pass for f(2). It calls f(1), which is ignored
256 //   due to OnFunctionCall returning kIgnoreDueToFirstPass and 1 is added to the
257 //   first pass queue. 2 is taked out of the first pass queue and moved to the
258 //   second pass stack. State: first_pass: [1], second_pass: [2].
259 // - Then we compute the first pass for 1. The similar thing happens: f(0) is
260 //   called and ignored, 0 is added to first_pass, 1 is added to second_pass.
261 //   State: first_pass: [0], second_pass: [2, 1].
262 // - Then we compute the first pass for 0. It doesn't make further calls, so
263 //   0 is moved to the second pass stack.
264 //   State: first_pass: [], second_pass: [2, 1, 0].
265 // - Then we compute the second pass for 0. It just returns 1.
266 //   State: first_pass: [], second_pass: [2, 1], results: {0: 1}.
267 // - Then we compute the second pass for 1. It calls f(0), which is memoized.
268 //   State: first_pass: [], second_pass: [2], results: {0: 1, 1: 1}.
269 // - Then we compute the second pass for 1. It calls f(1), which is memoized.
270 //   State: first_pass: [], second_pass: [], results: {0: 1, 1: 1, 2: 2}.
271 // - As both first_pass and second_pass are empty, we return from
272 //   RecursiveCallUnroller::Run.
273 // - Control is returned to CreatedFunction::Run for f(2), which returns
274 //   memoized value.
275 // - Then control is returned to CreatedFunction::Run for f(3), which completes
276 //   the computation.
277 class RecursiveCallUnroller {
278  public:
RecursiveCallUnroller(PerfettoSqlEngine * engine,sqlite3_stmt * stmt,const FunctionPrototype & prototype,Memoizer & memoizer)279   RecursiveCallUnroller(PerfettoSqlEngine* engine,
280                         sqlite3_stmt* stmt,
281                         const FunctionPrototype& prototype,
282                         Memoizer& memoizer)
283       : engine_(engine),
284         stmt_(stmt),
285         prototype_(prototype),
286         memoizer_(memoizer) {}
287 
288   // Whether we should just return null due to us being in the "first pass".
289   enum class FunctionCallState {
290     kIgnoreDueToFirstPass,
291     kEvaluate,
292   };
293 
OnFunctionCall(Memoizer::MemoizedArgs args)294   base::StatusOr<FunctionCallState> OnFunctionCall(
295       Memoizer::MemoizedArgs args) {
296     // If we are in the second pass, we just continue the function execution,
297     // including checking if a memoized value is available and returning it.
298     //
299     // We generally expect a memoized value to be available, but there are
300     // cases when it might not be the case, e.g. when which recursive calls are
301     // made depends on the return value of the function, e.g. for the following
302     // function, the first pass will not detect f(y) calls, so they will
303     // be computed recursively.
304     // f(x): SELECT max(f(y)) FROM y WHERE y < f($x - 1);
305     if (state_ == State::kComputingSecondPass) {
306       return FunctionCallState::kEvaluate;
307     }
308     if (!memoizer_.HasMemoizedValue(args)) {
309       ArgState* state = visited_.Find(args);
310       if (state) {
311         // Detect recursive loops, e.g. f(1) calling f(2) calling f(1).
312         if (*state == ArgState::kEvaluating) {
313           return base::ErrStatus("Infinite recursion detected");
314         }
315       } else {
316         visited_.Insert(args, ArgState::kScheduled);
317         first_pass_.push(args);
318       }
319     }
320     return FunctionCallState::kIgnoreDueToFirstPass;
321   }
322 
Run(Memoizer::MemoizedArgs initial_args)323   base::Status Run(Memoizer::MemoizedArgs initial_args) {
324     PERFETTO_TP_TRACE(metatrace::Category::FUNCTION_CALL,
325                       "UNROLL_RECURSIVE_FUNCTION_CALL",
326                       [&](metatrace::Record* r) {
327                         r->AddArg("Function", prototype_.function_name);
328                         r->AddArg("Arg 0", std::to_string(initial_args));
329                       });
330 
331     first_pass_.push(initial_args);
332     visited_.Insert(initial_args, ArgState::kScheduled);
333 
334     while (!first_pass_.empty() || !second_pass_.empty()) {
335       // If we have scheduled first pass calls, we evaluate them first.
336       if (!first_pass_.empty()) {
337         state_ = State::kComputingFirstPass;
338         Memoizer::MemoizedArgs args = first_pass_.front();
339 
340         PERFETTO_TP_TRACE(metatrace::Category::FUNCTION_CALL,
341                           "SQL_FUNCTION_CALL", [&](metatrace::Record* r) {
342                             r->AddArg("Function", prototype_.function_name);
343                             r->AddArg("Type", "UnrollRecursiveCall_FirstPass");
344                             r->AddArg("Arg 0", std::to_string(args));
345                           });
346 
347         first_pass_.pop();
348         second_pass_.push(args);
349         Evaluate(args).status();
350         continue;
351       }
352 
353       state_ = State::kComputingSecondPass;
354       Memoizer::MemoizedArgs args = second_pass_.top();
355 
356       PERFETTO_TP_TRACE(metatrace::Category::FUNCTION_CALL, "SQL_FUNCTION_CALL",
357                         [&](metatrace::Record* r) {
358                           r->AddArg("Function", prototype_.function_name);
359                           r->AddArg("Type", "UnrollRecursiveCall_SecondPass");
360                           r->AddArg("Arg 0", std::to_string(args));
361                         });
362 
363       visited_.Insert(args, ArgState::kEvaluating);
364       second_pass_.pop();
365       base::StatusOr<std::optional<int64_t>> result = Evaluate(args);
366       RETURN_IF_ERROR(result.status());
367       std::optional<int64_t> maybe_int_result = result.value();
368       if (!maybe_int_result.has_value()) {
369         continue;
370       }
371       visited_.Insert(args, ArgState::kEvaluated);
372       memoizer_.Memoize(args, SqlValue::Long(*maybe_int_result));
373     }
374     return base::OkStatus();
375   }
376 
377  private:
378   // This function returns:
379   // - base::ErrStatus if the evaluation of the function failed.
380   // - std::nullopt if the function returned a non-integer value.
381   // - the result of the function otherwise.
Evaluate(Memoizer::MemoizedArgs args)382   base::StatusOr<std::optional<int64_t>> Evaluate(Memoizer::MemoizedArgs args) {
383     RETURN_IF_ERROR(MaybeBindIntArgument(stmt_, prototype_.function_name,
384                                          prototype_.arguments[0], args));
385     base::StatusOr<SqlValue> result = EvaluateScalarStatement(
386         stmt_, engine_->sqlite_engine()->db(), prototype_);
387     sqlite3_reset(stmt_);
388     sqlite3_clear_bindings(stmt_);
389     RETURN_IF_ERROR(result.status());
390     if (result->type != SqlValue::Type::kLong) {
391       return std::optional<int64_t>(std::nullopt);
392     }
393     return std::optional<int64_t>(result->long_value);
394   }
395 
396   PerfettoSqlEngine* engine_;
397   sqlite3_stmt* stmt_;
398   const FunctionPrototype& prototype_;
399   Memoizer& memoizer_;
400 
401   // Current state of the evaluation.
402   enum class State {
403     kComputingFirstPass,
404     kComputingSecondPass,
405   };
406   State state_ = State::kComputingFirstPass;
407 
408   // A state of evaluation of a given argument.
409   enum class ArgState {
410     kScheduled,
411     kEvaluating,
412     kEvaluated,
413   };
414 
415   // See the class-level comment for the explanation of the two passes.
416   std::queue<Memoizer::MemoizedArgs> first_pass_;
417   base::FlatHashMap<Memoizer::MemoizedArgs, ArgState> visited_;
418   std::stack<Memoizer::MemoizedArgs> second_pass_;
419 };
420 
421 }  // namespace
422 
423 // This class is used to store the state of a CREATE_FUNCTION call.
424 // It is used to store the state of the function across multiple invocations
425 // of the function (e.g. when the function is called recursively).
426 class State : public CreatedFunction::Context {
427  public:
State(PerfettoSqlEngine * engine)428   explicit State(PerfettoSqlEngine* engine) : engine_(engine) {}
429   ~State() override;
430 
431   // Prepare a statement and push it into the stack of allocated statements
432   // for this function.
PrepareStatement()433   base::Status PrepareStatement() {
434     SqliteEngine::PreparedStatement stmt =
435         engine_->sqlite_engine()->PrepareStatement(*sql_);
436     RETURN_IF_ERROR(stmt.status());
437     is_valid_ = true;
438     stmts_.push_back(std::move(stmt));
439     return base::OkStatus();
440   }
441 
442   // Sets the state of the function. Should be called only when the function
443   // is invalid (i.e. when it is first created or when the previous statement
444   // failed to prepare).
Reset(FunctionPrototype prototype,sql_argument::Type return_type,SqlSource sql)445   void Reset(FunctionPrototype prototype,
446              sql_argument::Type return_type,
447              SqlSource sql) {
448     // Re-registration of valid functions is not allowed.
449     PERFETTO_DCHECK(!is_valid_);
450     PERFETTO_DCHECK(stmts_.empty());
451 
452     prototype_ = std::move(prototype);
453     return_type_ = return_type;
454     sql_ = std::move(sql);
455   }
456 
457   // This function is called each time the function is called.
458   // It ensures that we have a statement for the current recursion level,
459   // allocating a new one if needed.
PushStackEntry()460   base::Status PushStackEntry() {
461     ++current_recursion_level_;
462     if (current_recursion_level_ > stmts_.size()) {
463       return PrepareStatement();
464     }
465     return base::OkStatus();
466   }
467 
468   // Returns the statement that is used for the current invocation.
CurrentStatement()469   sqlite3_stmt* CurrentStatement() {
470     return stmts_[current_recursion_level_ - 1].sqlite_stmt();
471   }
472 
473   // This function is called each time the function returns and resets the
474   // statement that this invocation used.
PopStackEntry()475   void PopStackEntry() {
476     if (current_recursion_level_ > stmts_.size()) {
477       // This is possible if we didn't prepare the statement and returned
478       // an error.
479       return;
480     }
481     sqlite3_reset(CurrentStatement());
482     sqlite3_clear_bindings(CurrentStatement());
483     --current_recursion_level_;
484   }
485 
OnFunctionCall(Memoizer::MemoizedArgs args)486   base::StatusOr<RecursiveCallUnroller::FunctionCallState> OnFunctionCall(
487       Memoizer::MemoizedArgs args) {
488     if (!recursive_call_unroller_) {
489       return RecursiveCallUnroller::FunctionCallState::kEvaluate;
490     }
491     return recursive_call_unroller_->OnFunctionCall(args);
492   }
493 
494   // Called before checking the function for memoization.
UnrollRecursiveCallIfNeeded(Memoizer::MemoizedArgs args)495   base::Status UnrollRecursiveCallIfNeeded(Memoizer::MemoizedArgs args) {
496     if (!memoizer_.enabled() || !is_in_recursive_call() ||
497         recursive_call_unroller_) {
498       return base::OkStatus();
499     }
500     // If we are in a recursive call, we need to check if we have already
501     // computed the result for the current arguments.
502     if (memoizer_.HasMemoizedValue(args)) {
503       return base::OkStatus();
504     }
505 
506     // If we are in a beginning of a function call:
507     // - is a recursive,
508     // - can be memoized,
509     // - hasn't been memoized already, and
510     // - hasn't start unrolling yet;
511     // start the unrolling and run the unrolling loop.
512     recursive_call_unroller_ = std::make_unique<RecursiveCallUnroller>(
513         engine_, CurrentStatement(), prototype_, memoizer_);
514     auto status = recursive_call_unroller_->Run(args);
515     recursive_call_unroller_.reset();
516     return status;
517   }
518 
519   // Schedule a statement to be validated that it is indeed doesn't have any
520   // more rows.
ScheduleEmptyStatementValidation(sqlite3_stmt * stmt)521   void ScheduleEmptyStatementValidation(sqlite3_stmt* stmt) {
522     empty_stmts_to_validate_.push_back(stmt);
523   }
524 
ValidateEmptyStatements()525   base::Status ValidateEmptyStatements() {
526     while (!empty_stmts_to_validate_.empty()) {
527       sqlite3_stmt* stmt = empty_stmts_to_validate_.back();
528       empty_stmts_to_validate_.pop_back();
529       RETURN_IF_ERROR(
530           CheckNoMoreRows(stmt, engine_->sqlite_engine()->db(), prototype_));
531     }
532     return base::OkStatus();
533   }
534 
is_in_recursive_call() const535   bool is_in_recursive_call() const { return current_recursion_level_ > 1; }
536 
EnableMemoization()537   base::Status EnableMemoization() {
538     return memoizer_.EnableMemoization(prototype_);
539   }
540 
engine() const541   PerfettoSqlEngine* engine() const { return engine_; }
542 
prototype() const543   const FunctionPrototype& prototype() const { return prototype_; }
544 
return_type() const545   sql_argument::Type return_type() const { return return_type_; }
546 
sql() const547   const std::string& sql() const { return sql_->sql(); }
548 
is_valid() const549   bool is_valid() const { return is_valid_; }
550 
memoizer()551   Memoizer& memoizer() { return memoizer_; }
552 
553  private:
554   PerfettoSqlEngine* engine_;
555   FunctionPrototype prototype_;
556   sql_argument::Type return_type_;
557   std::optional<SqlSource> sql_;
558   // Perfetto SQL functions support recursion. Given that each function call in
559   // the stack requires a dedicated statement, we maintain a stack of prepared
560   // statements and use the top one for each new call (allocating a new one if
561   // needed).
562   std::vector<SqliteEngine::PreparedStatement> stmts_;
563   // A list of statements to verify to ensure that they don't have more rows
564   // in VerifyPostConditions.
565   std::vector<sqlite3_stmt*> empty_stmts_to_validate_;
566   size_t current_recursion_level_ = 0;
567   // Function re-registration is not allowed, but the user is allowed to define
568   // the function again if the first call failed. |is_valid_| flag helps that
569   // by tracking whether the current function definition is valid (in which case
570   // re-registration is not allowed).
571   bool is_valid_ = false;
572   Memoizer memoizer_;
573   // Set if we are in a middle of unrolling a recursive call.
574   std::unique_ptr<RecursiveCallUnroller> recursive_call_unroller_;
575 };
576 
577 State::~State() = default;
578 
MakeContext(PerfettoSqlEngine * engine)579 std::unique_ptr<CreatedFunction::Context> CreatedFunction::MakeContext(
580     PerfettoSqlEngine* engine) {
581   return std::make_unique<State>(engine);
582 }
583 
IsValid(Context * ctx)584 bool CreatedFunction::IsValid(Context* ctx) {
585   return static_cast<State*>(ctx)->is_valid();
586 }
587 
Reset(Context * ctx,PerfettoSqlEngine * engine)588 void CreatedFunction::Reset(Context* ctx, PerfettoSqlEngine* engine) {
589   ctx->~Context();
590   new (ctx) State(engine);
591 }
592 
Run(CreatedFunction::Context * ctx,size_t argc,sqlite3_value ** argv,SqlValue & out,Destructors &)593 base::Status CreatedFunction::Run(CreatedFunction::Context* ctx,
594                                   size_t argc,
595                                   sqlite3_value** argv,
596                                   SqlValue& out,
597                                   Destructors&) {
598   State* state = static_cast<State*>(ctx);
599 
600   // Enter the function and ensure that we have a statement allocated.
601   RETURN_IF_ERROR(state->PushStackEntry());
602 
603   if (argc != state->prototype().arguments.size()) {
604     return base::ErrStatus(
605         "%s: invalid number of args; expected %zu, received %zu",
606         state->prototype().function_name.c_str(),
607         state->prototype().arguments.size(), argc);
608   }
609 
610   // Type check all the arguments.
611   for (size_t i = 0; i < argc; ++i) {
612     sqlite3_value* arg = argv[i];
613     sql_argument::Type type = state->prototype().arguments[i].type();
614     base::Status status = sqlite::utils::TypeCheckSqliteValue(
615         arg, sql_argument::TypeToSqlValueType(type),
616         sql_argument::TypeToHumanFriendlyString(type));
617     if (!status.ok()) {
618       return base::ErrStatus("%s[arg=%s]: argument %zu %s",
619                              state->prototype().function_name.c_str(),
620                              sqlite3_value_text(arg), i, status.c_message());
621     }
622   }
623 
624   std::optional<Memoizer::MemoizedArgs> memoized_args =
625       Memoizer::AsMemoizedArgs(argc, argv);
626 
627   if (memoized_args) {
628     // If we are in the middle of an recursive calls unrolling, we might want to
629     // ignore the function invocation. See the comment in RecursiveCallUnroller
630     // for more details.
631     base::StatusOr<RecursiveCallUnroller::FunctionCallState> unroll_state =
632         state->OnFunctionCall(*memoized_args);
633     RETURN_IF_ERROR(unroll_state.status());
634     if (*unroll_state ==
635         RecursiveCallUnroller::FunctionCallState::kIgnoreDueToFirstPass) {
636       // Return NULL.
637       return base::OkStatus();
638     }
639 
640     RETURN_IF_ERROR(state->UnrollRecursiveCallIfNeeded(*memoized_args));
641 
642     std::optional<SqlValue> memoized_value =
643         state->memoizer().GetMemoizedValue(*memoized_args);
644     if (memoized_value) {
645       out = *memoized_value;
646       return base::OkStatus();
647     }
648   }
649 
650   PERFETTO_TP_TRACE(
651       metatrace::Category::FUNCTION_CALL, "SQL_FUNCTION_CALL",
652       [state, argv](metatrace::Record* r) {
653         r->AddArg("Function", state->prototype().function_name.c_str());
654         for (uint32_t i = 0; i < state->prototype().arguments.size(); ++i) {
655           std::string key = "Arg " + std::to_string(i);
656           const char* value =
657               reinterpret_cast<const char*>(sqlite3_value_text(argv[i]));
658           r->AddArg(base::StringView(key),
659                     value ? base::StringView(value) : base::StringView("NULL"));
660         }
661       });
662 
663   RETURN_IF_ERROR(
664       BindArguments(state->CurrentStatement(), state->prototype(), argc, argv));
665   auto result = EvaluateScalarStatement(state->CurrentStatement(),
666                                         state->engine()->sqlite_engine()->db(),
667                                         state->prototype());
668   RETURN_IF_ERROR(result.status());
669   out = result.value();
670   state->ScheduleEmptyStatementValidation(state->CurrentStatement());
671 
672   if (memoized_args) {
673     state->memoizer().Memoize(*memoized_args, out);
674   }
675 
676   return base::OkStatus();
677 }
678 
Cleanup(CreatedFunction::Context * ctx)679 void CreatedFunction::Cleanup(CreatedFunction::Context* ctx) {
680   // Clear the statement.
681   static_cast<State*>(ctx)->PopStackEntry();
682 }
683 
VerifyPostConditions(CreatedFunction::Context * ctx)684 base::Status CreatedFunction::VerifyPostConditions(
685     CreatedFunction::Context* ctx) {
686   return static_cast<State*>(ctx)->ValidateEmptyStatements();
687 }
688 
Prepare(CreatedFunction::Context * ctx,FunctionPrototype prototype,sql_argument::Type return_type,SqlSource source)689 base::Status CreatedFunction::Prepare(CreatedFunction::Context* ctx,
690                                       FunctionPrototype prototype,
691                                       sql_argument::Type return_type,
692                                       SqlSource source) {
693   State* state = static_cast<State*>(ctx);
694   state->Reset(std::move(prototype), return_type, std::move(source));
695 
696   // Ideally, we would unregister the function here if the statement prep
697   // failed, but SQLite doesn't allow unregistering functions inside active
698   // statements. So instead we'll just try to prepare the statement when calling
699   // this function, which will return an error.
700   return state->PrepareStatement();
701 }
702 
EnableMemoization(Context * ctx)703 base::Status CreatedFunction::EnableMemoization(Context* ctx) {
704   return static_cast<State*>(ctx)->EnableMemoization();
705 }
706 
707 }  // namespace trace_processor
708 }  // namespace perfetto
709