• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# Copyright (C) 2025 The Android Open Source Project
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#      http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16import sys
17from pathlib import Path
18import argparse
19import re
20from typing import Dict, List, Tuple, Union
21import typing as t
22
# Add sqlglot from buildtools to the Python path. The path is prepended rather
# than appended so the copy in buildtools takes precedence over any other
# sqlglot installed in site-packages: the Perfetto dialect below overrides
# sqlglot internals (e.g. Parser._parse_case), so it presumably depends on the
# pinned version.
ROOT_DIR = Path(__file__).parent.parent
SQLGLOT_DIR = ROOT_DIR / 'buildtools' / 'sqlglot'
sys.path.insert(0, str(SQLGLOT_DIR))
27
28import sqlglot
29from sqlglot.dialects.dialect import rename_func
30from sqlglot import exp
31from sqlglot.dialects.sqlite import SQLite
32from sqlglot.tokens import TokenType
33
34
class Perfetto(SQLite):
  """Perfetto SQL dialect implementation.

  Builds on sqlglot's SQLite dialect. The Generator customises type names,
  comment placement and layout. The Parser adds the Perfetto-specific
  statements `CREATE PERFETTO {TABLE,VIEW,FUNCTION,INDEX,MACRO}`,
  `INCLUDE PERFETTO MODULE` and `CREATE VIRTUAL TABLE ... USING ...`.
  """

  class Generator(SQLite.Generator):
    """Generator for Perfetto SQL dialect."""

    # CREATE kind mapping: the 'PERFETTO INDEX' kind maps to 'INDEX'.
    CREATABLE_KIND_MAPPING = {
        "PERFETTO INDEX": "INDEX",
    }

    # Perfetto's spelling of SQL type names (e.g. BIGINT is written LONG).
    TYPE_MAPPING = {
        **SQLite.Generator.TYPE_MAPPING,
        exp.DataType.Type.BOOLEAN: "BOOL",
        exp.DataType.Type.BIGINT: "LONG",
        exp.DataType.Type.TEXT: "STRING",
        exp.DataType.Type.DOUBLE: "DOUBLE",
    }

    # `!=` for inequality, SUBSTR for substrings, and RETURNS properties are
    # rendered by return_property below.
    TRANSFORMS = {
        **SQLite.Generator.TRANSFORMS,
        exp.NEQ:
            lambda self, e: self.binary(e, "!="),
        exp.Substring:
            rename_func("SUBSTR"),
        exp.ReturnsProperty:
            lambda self, e: self.return_property(e),
    }

    # RETURNS properties are placed at the POST_SCHEMA location.
    PROPERTIES_LOCATION = {
        **SQLite.Generator.PROPERTIES_LOCATION,
        exp.ReturnsProperty:
            exp.Properties.Location.POST_SCHEMA,
    }

    SUPPORTS_TABLE_ALIAS_COLUMNS = True

    JSON_KEY_VALUE_PAIR_SEP = ","

    def maybe_comment(self,
                      sql: str,
                      expression: t.Optional[exp.Expression] = None,
                      comments: t.Optional[t.List[str]] = None,
                      separated: bool = False):
      """Prefixes `sql` with comments rendered as `--` line comments.

      Comments come from `comments` when given, otherwise from
      `expression.comments`. Nothing is added when comment output is disabled
      or when `expression` is a Connector (connector_sql renders connector
      comments itself). `separated` is accepted for signature compatibility
      with the base class and is ignored.
      """
      comments = (((expression and expression.comments) if comments is None else
                   comments) if self.comments else None)
      if not comments or isinstance(expression, exp.Connector):
        return sql

      comments_sql = "\n".join(f"--{comment.rstrip()}" for comment in comments)
      if not comments_sql:
        return sql
      # Keep a leading newline when the annotated SQL started with whitespace;
      # the SQL itself follows the comments after self.sep().
      sep = ('\n' if sql and sql[0].isspace() else '') + comments_sql
      return f"{sep}{self.sep()}{sql.strip()}"

    def expressions(
        self,
        expression: t.Optional[exp.Expression] = None,
        key: t.Optional[str] = None,
        sqls: t.Optional[t.Collection[t.Union[str, exp.Expression]]] = None,
        flat: bool = False,
        indent: bool = True,
        skip_first: bool = False,
        skip_last: bool = False,
        sep: str = ", ",
        prefix: str = "",
        dynamic: bool = False,
        new_line: bool = False,
    ) -> str:
      """Renders a list of sub-expressions joined by `sep`.

      Override of the base-class helper. Each element is generated without
      its comments, and its comments are emitted in front of it via
      maybe_comment. When an element has comments, the separator after it is
      right-stripped.
      """
      expressions = expression.args.get(key or
                                        "expressions") if expression else sqls

      if not expressions:
        return ""

      if flat:
        return sep.join(
            sql for sql in (self.sql(e) for e in expressions) if sql)

      num_sqls = len(expressions)
      result_sqls = []

      for i, e in enumerate(expressions):
        sql = self.sql(e, comment=False)
        if not sql:
          continue

        comments = self.maybe_comment("", e) if isinstance(
            e, exp.Expression) else ""

        if self.pretty:
          if self.leading_comma:
            result_sqls.append(f"{sep if i > 0 else ''}{prefix}{comments}{sql}")
          else:
            result_sqls.append(
                f"{prefix}{comments}{sql}{(sep.rstrip() if comments else sep) if i + 1 < num_sqls else ''}"
            )
        else:
          result_sqls.append(
              f"{prefix}{comments}{sql}{sep if i + 1 < num_sqls else ''}")

      if self.pretty and (not dynamic or self.too_wide(result_sqls)):
        if new_line:
          result_sqls.insert(0, "")
          result_sqls.append("")
        result_sql = "\n".join(s.rstrip() for s in result_sqls)
      else:
        result_sql = "".join(result_sqls)

      return (self.indent(
          result_sql, skip_first=skip_first, skip_last=skip_last)
              if indent else result_sql)

    def not_sql(self, expression: exp.Not) -> str:
      """Renders NOT(x IS NULL) as `x IS NOT NULL`; otherwise uses the base."""
      if isinstance(expression.this, exp.Is) and isinstance(
          expression.this.right, exp.Null):
        return f"{self.sql(expression.this.left)} IS NOT NULL"
      return super().not_sql(expression)

    def connector_sql(
        self,
        expression: exp.Connector,
        op: str,
        stack: t.Optional[t.List[t.Union[str, exp.Expression]]] = None,
    ) -> str:
      """Renders AND/OR-style connectors, keeping their comments.

      When a `stack` is supplied, right operand, operator and left operand
      are pushed onto it (the operator prefixed by the connector's comments
      as `--` lines) and the operator is returned. Without a stack, the base
      implementation is used.
      """
      if stack is not None:
        if expression.expressions:
          stack.append(self.expressions(expression, sep=f" {op} "))
        else:
          if expression.comments and self.comments:
            comments = []
            for comment in expression.comments:
              if comment:
                comments.append(f"--{self.pad_comment(comment).rstrip()}")
            op = "\n".join(comments) + "\n" + op
          stack.extend((expression.right, op, expression.left))
        return op
      return super().connector_sql(expression, op, stack)

    def case_sql(self, expression: exp.Case) -> str:
      """Renders CASE, attaching each WHEN's comments in front of it.

      In pretty mode the statement is split over indented lines only when it
      is too wide for a single line.
      """
      this = self.sql(expression, "this")
      statements = [f"CASE {this}" if this else "CASE"]

      for e in expression.args["ifs"]:
        statements.append(self.maybe_comment(f"WHEN {self.sql(e, 'this')}", e))
        statements.append(f"THEN {self.sql(e, 'true')}")

      default = self.sql(expression, "default")

      if default:
        statements.append(f"ELSE {default}")

      statements.append("END")

      if self.pretty and self.too_wide(statements):
        return self.indent(
            "\n".join(statements), skip_first=True, skip_last=True)

      return " ".join(statements)

    def with_sql(self, expression: exp.With) -> str:
      """Renders WITH with the CTE list starting on the following line."""
      sql = self.expressions(expression)
      recursive = ("RECURSIVE " if self.CTE_RECURSIVE_KEYWORD_REQUIRED and
                   expression.args.get("recursive") else "")
      return f"WITH\n{recursive}{sql}"

    def return_property(self, expression: exp.ReturnsProperty) -> str:
      """Renders a function/macro return type as `RETURNS <type>`."""
      return f"RETURNS {self.sql(expression, 'this')}"

    # https://www.sqlite.org/lang_aggfunc.html#group_concat
    def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
      """Renders GROUP_CONCAT([DISTINCT ]x[, separator])."""
      this = expression.this
      distinct = expression.find(exp.Distinct)

      if distinct:
        this = distinct.expressions[0]
        distinct_sql = "DISTINCT "
      else:
        distinct_sql = ""

      separator = expression.args.get("separator")
      return f"GROUP_CONCAT({distinct_sql}{self.format_args(this, separator)})"

  class Parser(SQLite.Parser):
    """Parser for Perfetto SQL dialect."""

    # CREATE statements and statements starting with a bare word (used for
    # `INCLUDE PERFETTO MODULE`) go through the Perfetto-aware parsers below.
    STATEMENT_PARSERS = {
        **SQLite.Parser.STATEMENT_PARSERS,
        TokenType.CREATE:
            lambda self: self._parse_create_override(),
        TokenType.VAR:
            lambda self: self._parse_var_override(),
    }

    def _parse_create_override(self):
      """Parses CREATE, handling the Perfetto-specific forms.

      Handles `CREATE VIRTUAL TABLE ...` and
      `CREATE PERFETTO {TABLE,VIEW,FUNCTION,INDEX,MACRO} ...`. Any other
      CREATE statement is delegated to the base parser.
      """
      if self._match_text_seq("VIRTUAL"):
        return self._parse_create_virtual_table_override()

      if not self._match_text_seq("PERFETTO"):
        return self._parse_create()

      is_table = self._match_text_seq("TABLE")
      is_view = self._match_text_seq("VIEW")
      is_function = self._match_text_seq("FUNCTION")
      is_index = self._match_text_seq("INDEX")
      is_macro = self._match_text_seq("MACRO")
      if not is_table and not is_view and not is_function and not is_index and not is_macro:
        return self.raise_error(
            "Expected 'TABLE', 'VIEW', 'FUNCTION', 'INDEX' or 'MACRO'")

      if is_index:
        # Parse index name
        name = self._parse_id_var()

        # Parse ON
        if not self._match_text_seq("ON"):
          return self.raise_error("Expected 'ON'")

        # Parse table name
        table = self._parse_table()

        return exp.Create(
            this=self.expression(
                exp.Index,
                this=name,
                table=table,
            ),
            kind='PERFETTO INDEX',
        )

      if is_function or is_macro:
        # Parse function name
        udf = self._parse_user_defined_function()

        # Parse RETURNS type
        if not self._match_text_seq("RETURNS"):
          return self.raise_error("Expected 'RETURNS'")
        # Comments attached to the RETURNS keyword are carried over onto the
        # parsed return-type node so they survive formatting.
        return_comments = self._prev_comments
        return_type = self._parse_returns()
        return_type.comments = return_comments

        # Parse AS
        if not self._match(TokenType.ALIAS):
          return self.raise_error("Expected 'AS'")

        if is_function:
          # Parse function body
          body = self._parse_select()
        else:
          # Macro bodies: a _ProjectionFragment macro expands to a select
          # list, anything else to a single expression.
          if str(return_type.this).upper() == "_PROJECTIONFRAGMENT":
            body = self._parse_projections()
          else:
            body = self._parse_expression()

        return exp.Create(
            this=udf,
            kind="PERFETTO FUNCTION" if is_function else "PERFETTO MACRO",
            expression=body,
            properties=self.expression(
                exp.Properties,
                expressions=[
                    return_type,
                ],
            ),
        )

      # Parse view/table name
      table = self._parse_table(schema=True)

      # Parse AS
      if not self._match(TokenType.ALIAS):
        return self.raise_error("Expected 'AS'")

      # Parse SELECT statement
      select = self._parse_select()

      return exp.Create(
          this=table,
          kind='PERFETTO VIEW' if is_view else 'PERFETTO TABLE',
          expression=select)

    def _parse_var_override(self):
      """Parses `INCLUDE PERFETTO MODULE a.b.c` into an exp.Command.

      Any other statement starting with a bare word is rejected.
      """
      assert self._prev
      if self._prev.text.upper() != "INCLUDE":
        return self.raise_error("Expected 'INCLUDE'")

      # Expect 'PERFETTO'
      if not self._match_text_seq('PERFETTO'):
        return self.raise_error("Expected 'PERFETTO'")

      # Expect 'MODULE'
      if not self._match_text_seq('MODULE'):
        return self.raise_error("Expected 'MODULE'")

      # Parse the module path (e.g. android.suspend)
      module_path = []
      while True:
        part = self._parse_id_var()
        if not part:
          break
        module_path.append(part.text(key="this"))
        if not self._match(TokenType.DOT):
          break
      return exp.Command(this="INCLUDE PERFETTO MODULE " +
                         '.'.join(module_path))

    def _parse_create_virtual_table_override(self):
      """Parses `CREATE VIRTUAL TABLE <name> USING <module>(<args>)`.

      The argument list is not interpreted. Its raw text, including the
      enclosing parentheses, is kept verbatim in an exp.Command.
      """
      if not self._match(TokenType.TABLE):
        return self.raise_error("Expected 'TABLE'")

      name = self._parse_id_var()
      if not name:
        return self.raise_error("Expected table name")

      if not self._match_text_seq('USING'):
        return self.raise_error("Expected 'USING'")

      sp = self._parse_id_var()
      if not sp:
        return self.raise_error("Expected module name after 'USING'")
      self._match_l_paren()
      start = self._prev

      # Skip to the ')' matching the opening '('. Nesting is tracked so that
      # arguments containing parentheses are captured whole. Running out of
      # tokens is reported instead of spinning forever: at end of input
      # self._match() keeps returning None, so a naive "advance until ')'"
      # loop never terminates.
      depth = 1
      while depth:
        if not self._curr:
          return self.raise_error("Expected ')'")
        if self._curr.token_type == TokenType.L_PAREN:
          depth += 1
        elif self._curr.token_type == TokenType.R_PAREN:
          depth -= 1
        self._advance()

      # self._prev is now the closing ')'.
      return exp.Command(
          this="CREATE VIRTUAL TABLE " + name.text(key="this") + " USING " +
          sp.text(key="this") + " " + self._find_sql(start, self._prev),)

    def _parse_case(self) -> t.Optional[exp.Expression]:
      """Parses a CASE expression, keeping comments on CASE and each WHEN."""
      ifs = []
      default = None

      comments = self._prev_comments
      expression = self._parse_assignment()

      while self._match(TokenType.WHEN):
        when_comments = self._prev_comments
        this = self._parse_assignment()
        self._match(TokenType.THEN)
        then = self._parse_assignment()
        ifs.append(
            self.expression(
                exp.If, this=this, true=then, comments=when_comments))

      if self._match(TokenType.ELSE):
        default = self._parse_assignment()

      if not self._match(TokenType.END):
        # `ELSE interval END` gets parsed as an INTERVAL expression whose
        # operand is END, consuming the END keyword. Recover by treating
        # `interval` as a column reference.
        if isinstance(default,
                      exp.Interval) and default.this.sql().upper() == "END":
          default = exp.column("interval")
        else:
          self.raise_error("Expected END after CASE", self._prev)

      return self.expression(
          exp.Case,
          comments=comments,
          this=expression,
          ifs=ifs,
          default=default)

    def _parse_is(
        self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
      """Parses the right-hand side of `IS [NOT] ...`.

      Handles IS [NOT] DISTINCT FROM, IS [NOT] JSON, and otherwise accepts
      any expression after IS (parsed with self._parse_expression()).
      """
      index = self._index - 1
      negate = self._match(TokenType.NOT)

      if self._match_text_seq("DISTINCT", "FROM"):
        # IS NOT DISTINCT FROM is null-safe equality; IS DISTINCT FROM is its
        # negation.
        klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
        return self.expression(
            klass, this=this, expression=self._parse_bitwise())

      if self._match(TokenType.JSON):
        kind = self._match_texts(
            self.IS_JSON_PREDICATE_KIND) and self._prev.text.upper()

        if self._match_text_seq("WITH"):
          _with = True
        elif self._match_text_seq("WITHOUT"):
          _with = False
        else:
          _with = None

        unique = self._match(TokenType.UNIQUE)
        self._match_text_seq("KEYS")
        expression: t.Optional[exp.Expression] = self.expression(
            exp.JSON, **{
                "this": kind,
                "with": _with,
                "unique": unique
            })
      else:
        expression = self._parse_expression()
        if not expression:
          # Not an IS predicate after all: rewind to before the IS token.
          self._retreat(index)
          return None

      this = self.expression(exp.Is, this=this, expression=expression)
      return self.expression(exp.Not, this=this) if negate else this
428
429
def preprocess_macros(sql: str) -> Tuple[str, List[Tuple[str, str]]]:
  """Convert macro calls to placeholders for sqlglot parsing.

  A macro call is an identifier, `!`, and a parenthesised argument list,
  e.g. `foo!(a, (b))`. Only outermost calls are replaced. A macro used inside
  another macro's arguments stays part of the outer call's text, so restoring
  the outer placeholder brings it back verbatim. Calls whose parentheses are
  never closed are left untouched.

  Args:
    sql: Input SQL string

  Returns:
    Tuple of (processed SQL, list of (placeholder, original macro text) pairs)
  """
  macro_call = re.compile(r'(\w+)\s*!\s*\(')
  pieces: List[str] = []
  macros: List[Tuple[str, str]] = []
  copied_up_to = 0  # Start of the part of `sql` not yet copied to `pieces`.

  match = macro_call.search(sql)
  while match:
    # Walk forward to the parenthesis that closes this call.
    i = match.end()
    paren = 1
    while i < len(sql) and paren > 0:
      if sql[i] == '(':
        paren += 1
      elif sql[i] == ')':
        paren -= 1
      i += 1

    if paren == 0:
      placeholder = f'__macro_{len(macros)}__'
      macros.append((placeholder, sql[match.start():i]))
      pieces.append(sql[copied_up_to:match.start()])
      pieces.append(placeholder)
      copied_up_to = i
      # Replace by position and resume after the whole call. Text-based
      # replacement breaks when a nested macro was already swapped out, since
      # the outer call's original text then no longer appears verbatim.
      match = macro_call.search(sql, i)
    else:
      # Unterminated call: leave it as-is and keep scanning after its '('.
      match = macro_call.search(sql, match.end())

  pieces.append(sql[copied_up_to:])
  return ''.join(pieces), macros
460
461
def postprocess_macros(sql: str, macros: List[Tuple[str, str]]) -> str:
  """Put the original macro calls back where their placeholders are.

  Substitution happens in the order the pairs appear in `macros`, and every
  occurrence of each placeholder is replaced.

  Args:
    sql: Formatted SQL with placeholders
    macros: List of (placeholder, original macro text) pairs from preprocess_macros

  Returns:
    SQL with macros restored
  """
  text = sql
  for marker, original_call in macros:
    text = text.replace(marker, original_call)
  return text
476
477
def extract_comment_blocks(sql: str) -> Tuple[str, Dict[str, str]]:
  """Swap standalone comment blocks for placeholder comments.

  The input is split into chunks of consecutive non-blank lines (lines that
  are empty or whitespace-only separate chunks). A chunk whose every line
  starts with `--` (after stripping) is a comment block. It is replaced by a
  single `-- __COMMENT_BLOCK_<chunk index>__` line. The output joins all
  chunks with blank lines and strips the surrounding whitespace.

  Args:
    sql: Input SQL string

  Returns:
    A tuple containing:
      - Processed SQL with comment blocks replaced by placeholders
      - Dict mapping placeholders to their original comment blocks (each
        block's lines joined by newlines, plus a trailing newline)
  """
  # Group consecutive non-blank lines; a blank line closes the open group.
  chunks: List[List[str]] = [[]]
  for line in sql.splitlines():
    if line.strip():
      chunks[-1].append(line)
    elif chunks[-1]:
      chunks.append([])
  if not chunks[-1]:
    chunks.pop()

  blocks: Dict[str, str] = {}
  pieces: List[str] = []
  for index, chunk in enumerate(chunks):
    is_comment_block = all(line.strip().startswith('--') for line in chunk)
    if is_comment_block:
      marker = f'-- __COMMENT_BLOCK_{index}__'
      blocks[marker] = '\n'.join(chunk) + '\n'
      body = [marker]
    else:
      body = chunk
    # Each chunk is surrounded by empty entries, i.e. blank lines once joined.
    pieces += ['', *body, '']

  return '\n'.join(pieces).strip(), blocks
521
522
def restore_comment_blocks(sql: str, blocks: Dict[str, str]) -> str:
  """Substitute each comment-block placeholder with its original text.

  Placeholders are processed in the dict's iteration order and every
  occurrence of each one is replaced.

  Args:
    sql: SQL string with placeholders
    blocks: Dict mapping placeholders to original comment blocks

  Returns:
    SQL string with comment blocks restored in their original positions
  """
  text = sql
  for marker in blocks:
    text = text.replace(marker, blocks[marker])
  return text
537
538
def format_sql(file: Path,
               sql: str,
               indent_width: int = 2,
               verbose: bool = False) -> str:
  """Format SQL content with consistent style.

  Input containing the marker `-- sqlformat file off` is returned unchanged.
  Otherwise, standalone comment blocks and macro calls are swapped for
  placeholders. The remaining SQL is parsed and pretty-printed with the
  Perfetto dialect, and then the placeholders are restored.

  Args:
    file: Path to the SQL file (for error reporting)
    sql: SQL content to format
    indent_width: Number of spaces for indentation
    verbose: Whether to print status messages

  Returns:
    Formatted SQL string, ending in a newline. Input that yields no
    statements at all is returned unchanged.

  Raises:
    Exception: If SQL parsing or formatting fails
  """
  if '-- sqlformat file off' in sql:
    if verbose:
      print(f"Ignoring {file}", file=sys.stderr)
    return sql

  # First extract comment blocks
  sql_with_placeholders, comment_blocks = extract_comment_blocks(sql)

  # Then process macros
  processed, macros = preprocess_macros(sql_with_placeholders)
  try:
    # sqlglot.parse returns a list of Optional expressions; empty statements
    # (empty input, a stray `;;`) come back as None and have nothing to print.
    statements = [
        ast for ast in sqlglot.parse(sql=processed, dialect=Perfetto)
        if ast is not None
    ]
    if not statements:
      return sql

    formatted = ''.join(
        ast.sql(
            pretty=True,
            dialect=Perfetto,
            indent=indent_width,
            normalize=True,
            normalize_functions='lower',
        ) + ";\n\n" for ast in statements)

    # Restore macros first, then comment blocks
    with_macros = postprocess_macros(formatted, macros)
    return restore_comment_blocks(with_macros, comment_blocks).rstrip() + '\n'
  except Exception as e:
    print(f"Failed to format SQL: file {file}, {e}", file=sys.stderr)
    raise
585
586
def format_files_in_place(paths: List[Union[str, Path]],
                          indent_width: int = 2,
                          verbose: bool = False) -> None:
  """Format multiple SQL files in place.

  Directories are searched recursively for `*.sql` files, visited in sorted
  order so runs are reproducible. Files are read and written as UTF-8
  regardless of the locale. A file is only rewritten when formatting changes
  its content, which leaves already-formatted files untouched.

  Args:
      paths: List of file or directory paths to format
      indent_width: Number of spaces for indentation
      verbose: Whether to print status messages
  """
  for path in paths:
    path = Path(path)
    # Directories expand to every .sql file beneath them; anything else is
    # treated as a single file.
    sql_files = sorted(path.rglob('*.sql')) if path.is_dir() else [path]

    for sql_file in sql_files:
      sql = sql_file.read_text(encoding='utf-8')
      # Formatting happens before the file is opened for writing, so a
      # formatting error leaves the original content intact.
      formatted = format_sql(sql_file, sql, indent_width, verbose)
      if formatted != sql:
        sql_file.write_text(formatted, encoding='utf-8')
      if verbose:
        print(f"Formatted {sql_file}", file=sys.stderr)
608        sql = f.read()
609      formatted = format_sql(sql_file, sql, indent_width, verbose)
610      with open(sql_file, 'w') as f:
611        f.write(formatted)
612      if verbose:
613        print(f"Formatted {sql_file}", file=sys.stderr)
614
615
616def check_sql_formatting(paths: List[Union[str, Path]],
617                         indent_width: int = 2,
618                         verbose: bool = False) -> bool:
619  """Check SQL files for formatting violations without making changes.
620
621  Args:
622      paths: List of file or directory paths to check
623      indent_width: Number of spaces for indentation
624      verbose: Whether to print status messages
625
626  Returns:
627      True if all files are properly formatted, False otherwise
628  """
629  all_formatted = True
630  for path in paths:
631    path = Path(path)
632    if path.is_dir():
633      sql_files = path.rglob('*.sql')
634    else:
635      sql_files = [path]
636
637    for sql_file in sql_files:
638      with open(sql_file) as f:
639        sql = f.read()
640      formatted = format_sql(sql_file, sql, indent_width, verbose)
641      if formatted != sql:
642        print(f"Would format {sql_file}", file=sys.stderr)
643        all_formatted = False
644
645  return all_formatted
646
647
def main() -> None:
  """Main entry point.

  Modes, in order of precedence:
    --check-only: report files that would change; exit 1 if any would.
    --in-place:   rewrite the given files/directories.
    otherwise:    format SQL read from a non-interactive stdin to stdout,
                  or print help when stdin is a terminal.
  """
  # NOTE(review): sqlglot is imported at module load time, so this check is
  # only reached if that import already succeeded (e.g. from another
  # installation of sqlglot).
  if not SQLGLOT_DIR.exists():
    print(f"{SQLGLOT_DIR} does not exist. Run tools/install-build-deps first.",
          file=sys.stderr)
    sys.exit(1)

  parser = argparse.ArgumentParser(
      description='Format SQL queries with consistent style')
  parser.add_argument(
      'paths',
      nargs='*',
      help='Paths to SQL files or directories containing SQL files')
  parser.add_argument(
      '--indent-width',
      type=int,
      default=2,
      help='Number of spaces for indentation (default: 2)')
  parser.add_argument(
      '--in-place', action='store_true', help='Format files in place')
  parser.add_argument(
      '--check-only',
      action='store_true',
      help='Check for formatting violations without making changes')
  parser.add_argument(
      '--verbose',
      action='store_true',
      help='Print status messages during execution')
  args = parser.parse_args()

  if args.check_only:
    properly_formatted = check_sql_formatting(args.paths, args.indent_width,
                                              args.verbose)
    sys.exit(0 if properly_formatted else 1)
  elif args.in_place:
    format_files_in_place(args.paths, args.indent_width, args.verbose)
  else:
    # Read from stdin if no files provided.
    # NOTE(review): `paths` given without --in-place/--check-only are
    # currently ignored in this mode.
    if not sys.stdin.isatty():
      sql_input = sys.stdin.read()
      formatted_sql = format_sql(Path("stdin"), sql_input, args.indent_width,
                                 args.verbose)
      # format_sql returns complete file contents (normally already
      # newline-terminated); write them verbatim so print() does not append
      # a second newline.
      sys.stdout.write(formatted_sql)
    else:
      # Print help if no input provided
      parser.print_help()
692
693
if __name__ == "__main__":
  # Run the CLI only when executed as a script, not when imported.
  main()
696