#!/usr/bin/env python3
# Copyright (C) 2025 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Formats SQL files using a sqlglot dialect derived from SQLite.

The dialect adds parsing for `CREATE PERFETTO ...`, `CREATE VIRTUAL TABLE` and
`INCLUDE PERFETTO MODULE` statements. Macro calls (`name!(...)`) and
stand-alone comment blocks are swapped for placeholders before parsing and
restored afterwards.
"""

import sys
from pathlib import Path
import argparse
import re
from typing import Dict, List, Tuple, Union
import typing as t

# Add sqlglot from buildtools to the Python path
ROOT_DIR = Path(__file__).parent.parent
SQLGLOT_DIR = ROOT_DIR / 'buildtools' / 'sqlglot'
sys.path.append(str(SQLGLOT_DIR))

import sqlglot
from sqlglot.dialects.dialect import rename_func
from sqlglot import exp
from sqlglot.dialects.sqlite import SQLite
from sqlglot.tokens import TokenType

# Matches the head of a macro invocation, e.g. `my_macro!(`.
_MACRO_CALL_RE = re.compile(r'(\w+)\s*!\s*\(')


class Perfetto(SQLite):
  """Perfetto SQL dialect implementation."""

  class Generator(SQLite.Generator):
    """Generator for Perfetto SQL dialect."""

    CREATABLE_KIND_MAPPING = {
        "PERFETTO INDEX": "INDEX",
    }

    TYPE_MAPPING = {
        **SQLite.Generator.TYPE_MAPPING,
        exp.DataType.Type.BOOLEAN: "BOOL",
        exp.DataType.Type.BIGINT: "LONG",
        exp.DataType.Type.TEXT: "STRING",
        exp.DataType.Type.DOUBLE: "DOUBLE",
    }

    TRANSFORMS = {
        **SQLite.Generator.TRANSFORMS,
        exp.NEQ:
            lambda self, e: self.binary(e, "!="),
        exp.Substring:
            rename_func("SUBSTR"),
        exp.ReturnsProperty:
            lambda self, e: self.return_property(e),
    }

    PROPERTIES_LOCATION = {
        **SQLite.Generator.PROPERTIES_LOCATION,
        exp.ReturnsProperty:
            exp.Properties.Location.POST_SCHEMA,
    }

    SUPPORTS_TABLE_ALIAS_COLUMNS = True

    JSON_KEY_VALUE_PAIR_SEP = ","

    def maybe_comment(self,
                      sql: str,
                      expression: t.Optional[exp.Expression] = None,
                      comments: t.Optional[t.List[str]] = None,
                      separated: bool = False):
      """Prefix `sql` with its comments rendered as `--` line comments.

      Args:
        sql: Already generated SQL for `expression`.
        expression: Node whose comments are used when `comments` is None.
        comments: Explicit comments overriding those on `expression`.
        separated: Unused here; kept so the signature stays call-compatible
          with the base generator.

      Returns:
        `sql` unchanged when comments are disabled, absent, or the node is a
        connector (those are emitted by `connector_sql`); otherwise the
        comment lines followed by the stripped SQL.
      """
      comments = (((expression and expression.comments) if comments is None else
                   comments) if self.comments else None)
      if not comments or isinstance(expression, exp.Connector):
        return sql

      comments_sql = "\n".join(f"--{comment.rstrip()}" for comment in comments)
      if not comments_sql:
        return sql
      # Keep a leading line break if the SQL itself started with whitespace.
      sep = ('\n' if sql and sql[0].isspace() else '') + comments_sql
      return f"{sep}{self.sep()}{sql.strip()}"

    def expressions(
        self,
        expression: t.Optional[exp.Expression] = None,
        key: t.Optional[str] = None,
        sqls: t.Optional[t.Collection[t.Union[str, exp.Expression]]] = None,
        flat: bool = False,
        indent: bool = True,
        skip_first: bool = False,
        skip_last: bool = False,
        sep: str = ", ",
        prefix: str = "",
        dynamic: bool = False,
        new_line: bool = False,
    ) -> str:
      """Render a list of child expressions, placing comments before each item.

      The items come from `expression.args[key or "expressions"]` when
      `expression` is given, otherwise from `sqls`. Returns "" when there is
      nothing to render.
      """
      expressions = expression.args.get(key or
                                        "expressions") if expression else sqls

      if not expressions:
        return ""

      if flat:
        return sep.join(
            sql for sql in (self.sql(e) for e in expressions) if sql)

      num_sqls = len(expressions)
      result_sqls = []

      for i, e in enumerate(expressions):
        # Comments are rendered separately so they can go in front of the item.
        sql = self.sql(e, comment=False)
        if not sql:
          continue

        comments = self.maybe_comment("", e) if isinstance(
            e, exp.Expression) else ""

        if self.pretty:
          if self.leading_comma:
            result_sqls.append(f"{sep if i > 0 else ''}{prefix}{comments}{sql}")
          else:
            result_sqls.append(
                f"{prefix}{comments}{sql}{(sep.rstrip() if comments else sep) if i + 1 < num_sqls else ''}"
            )
        else:
          result_sqls.append(
              f"{prefix}{comments}{sql}{sep if i + 1 < num_sqls else ''}")

      if self.pretty and (not dynamic or self.too_wide(result_sqls)):
        if new_line:
          result_sqls.insert(0, "")
          result_sqls.append("")
        result_sql = "\n".join(s.rstrip() for s in result_sqls)
      else:
        result_sql = "".join(result_sqls)

      return (self.indent(
          result_sql, skip_first=skip_first, skip_last=skip_last)
              if indent else result_sql)

    def not_sql(self, expression: exp.Not) -> str:
      """Render `NOT (x IS NULL)` as `x IS NOT NULL`; defer otherwise."""
      if isinstance(expression.this, exp.Is) and isinstance(
          expression.this.right, exp.Null):
        return f"{self.sql(expression.this.left)} IS NOT NULL"
      return super().not_sql(expression)

    def connector_sql(
        self,
        expression: exp.Connector,
        op: str,
        stack: t.Optional[t.List[t.Union[str, exp.Expression]]] = None,
    ) -> str:
      """Render a connector, emitting its comments as `--` lines before `op`.

      When `stack` is provided the operands and operator are pushed onto it
      and the (possibly comment-prefixed) operator is returned; otherwise the
      base implementation is used.
      """
      if stack is not None:
        if expression.expressions:
          stack.append(self.expressions(expression, sep=f" {op} "))
        else:
          if expression.comments and self.comments:
            comments = [
                f"--{self.pad_comment(comment).rstrip()}"
                for comment in expression.comments
                if comment
            ]
            op = "\n".join(comments) + "\n" + op
          stack.extend((expression.right, op, expression.left))
        return op
      return super().connector_sql(expression, op, stack)

    def case_sql(self, expression: exp.Case) -> str:
      """Render a CASE expression, keeping comments attached to each WHEN."""
      this = self.sql(expression, "this")
      statements = [f"CASE {this}" if this else "CASE"]

      for e in expression.args["ifs"]:
        statements.append(self.maybe_comment(f"WHEN {self.sql(e, 'this')}", e))
        statements.append(f"THEN {self.sql(e, 'true')}")

      default = self.sql(expression, "default")

      if default:
        statements.append(f"ELSE {default}")

      statements.append("END")

      if self.pretty and self.too_wide(statements):
        return self.indent(
            "\n".join(statements), skip_first=True, skip_last=True)

      return " ".join(statements)

    def with_sql(self, expression: exp.With) -> str:
      """Render a WITH clause with the CTE list starting on its own line."""
      sql = self.expressions(expression)
      recursive = ("RECURSIVE " if self.CTE_RECURSIVE_KEYWORD_REQUIRED and
                   expression.args.get("recursive") else "")
      return f"WITH\n{recursive}{sql}"

    def return_property(self, expression: exp.ReturnsProperty) -> str:
      """Render a RETURNS property as `RETURNS <type>`."""
      return f"RETURNS {self.sql(expression, 'this')}"

    # https://www.sqlite.org/lang_aggfunc.html#group_concat
    def groupconcat_sql(self, expression: exp.GroupConcat) -> str:
      """Render GROUP_CONCAT, hoisting DISTINCT in front of the argument."""
      this = expression.this
      distinct = expression.find(exp.Distinct)

      if distinct:
        this = distinct.expressions[0]
        distinct_sql = "DISTINCT "
      else:
        distinct_sql = ""

      separator = expression.args.get("separator")
      return f"GROUP_CONCAT({distinct_sql}{self.format_args(this, separator)})"

  class Parser(SQLite.Parser):
    """Parser for Perfetto SQL dialect."""

    STATEMENT_PARSERS = {
        **SQLite.Parser.STATEMENT_PARSERS,
        TokenType.CREATE:
            lambda self: self._parse_create_override(),
        TokenType.VAR:
            lambda self: self._parse_var_override(),
    }

    def _parse_create_override(self):
      """Parse CREATE statements, including the PERFETTO/VIRTUAL variants.

      Plain CREATE statements are delegated to the base parser. For
      `CREATE PERFETTO <kind>` an `exp.Create` is returned whose `kind` is
      'PERFETTO TABLE', 'PERFETTO VIEW', 'PERFETTO FUNCTION',
      'PERFETTO MACRO' or 'PERFETTO INDEX'.
      """
      if self._match_text_seq("VIRTUAL"):
        return self._parse_create_virtual_table_override()

      if not self._match_text_seq("PERFETTO"):
        return self._parse_create()

      # Stop at the first keyword that matches: probing every keyword in turn
      # would also swallow an object *name* that happens to spell one of them
      # (e.g. a table called `index`).
      kind = None
      for keyword in ("TABLE", "VIEW", "FUNCTION", "INDEX", "MACRO"):
        if self._match_text_seq(keyword):
          kind = keyword
          break
      if kind is None:
        return self.raise_error(
            "Expected 'TABLE', 'VIEW', 'FUNCTION', 'INDEX' or 'MACRO'")

      if kind == "INDEX":
        # Parse index name
        name = self._parse_id_var()

        # Parse ON
        if not self._match_text_seq("ON"):
          return self.raise_error("Expected 'ON'")

        # Parse table name
        table = self._parse_table()

        return exp.Create(
            this=self.expression(
                exp.Index,
                this=name,
                table=table,
            ),
            kind='PERFETTO INDEX',
        )

      if kind in ("FUNCTION", "MACRO"):
        is_function = kind == "FUNCTION"

        # Parse function name
        udf = self._parse_user_defined_function()

        # Parse RETURNS type
        if not self._match_text_seq("RETURNS"):
          return self.raise_error("Expected 'RETURNS'")
        # Keep comments written next to RETURNS attached to the return type.
        return_comments = self._prev_comments
        return_type = self._parse_returns()
        return_type.comments = return_comments

        # Parse AS
        if not self._match(TokenType.ALIAS):
          return self.raise_error("Expected 'AS'")

        if is_function:
          # Parse function body
          body = self._parse_select()
        elif str(return_type.this).upper() == "_PROJECTIONFRAGMENT":
          body = self._parse_projections()
        else:
          body = self._parse_expression()

        return exp.Create(
            this=udf,
            kind="PERFETTO FUNCTION" if is_function else "PERFETTO MACRO",
            expression=body,
            properties=self.expression(
                exp.Properties,
                expressions=[
                    return_type,
                ],
            ),
        )

      # Parse view/table name
      table = self._parse_table(schema=True)

      # Parse AS
      if not self._match(TokenType.ALIAS):
        return self.raise_error("Expected 'AS'")

      # Parse SELECT statement
      select = self._parse_select()

      return exp.Create(
          this=table,
          kind='PERFETTO VIEW' if kind == "VIEW" else 'PERFETTO TABLE',
          expression=select)

    def _parse_var_override(self):
      """Parse `INCLUDE PERFETTO MODULE a.b.c` into an `exp.Command`."""
      assert self._prev
      if self._prev.text.upper() != "INCLUDE":
        return self.raise_error("Expected 'INCLUDE'")

      # Expect 'PERFETTO'
      if not self._match_text_seq('PERFETTO'):
        return self.raise_error("Expected 'PERFETTO'")

      # Expect 'MODULE'
      if not self._match_text_seq('MODULE'):
        return self.raise_error("Expected 'MODULE'")

      # Parse the module path (e.g. android.suspend)
      module_path = []
      while True:
        ident = self._parse_id_var()
        if not ident:
          break
        module_path.append(ident.text(key="this"))
        if not self._match(TokenType.DOT):
          break
      return exp.Command(this="INCLUDE PERFETTO MODULE " +
                         '.'.join(module_path))

    def _parse_create_virtual_table_override(self):
      """Parse `CREATE VIRTUAL TABLE <name> USING <module>(...)`.

      The parenthesised argument list is not interpreted; its source text is
      copied verbatim into the returned `exp.Command`.
      """
      if not self._match(TokenType.TABLE):
        return self.raise_error("Expected 'TABLE'")

      name = self._parse_id_var()
      if not name:
        return self.raise_error("Expected virtual table name")

      if not self._match_text_seq('USING'):
        return self.raise_error("Expected 'USING'")

      sp = self._parse_id_var()
      if not sp:
        return self.raise_error("Expected module name after 'USING'")
      self._match_l_paren()
      start = self._prev

      while not self._match(TokenType.R_PAREN):
        # Without this check an unterminated argument list loops forever:
        # once the tokens are exhausted `_match` can never succeed.
        if not self._curr:
          return self.raise_error("Expected ')'")
        self._advance()

      return exp.Command(
          this="CREATE VIRTUAL TABLE " + name.text(key="this") + " USING " +
          sp.text(key="this") + " " + self._find_sql(start, self._prev),)

    def _parse_case(self) -> t.Optional[exp.Expression]:
      """Parse a CASE expression, recording the comments seen at each WHEN."""
      ifs = []
      default = None

      comments = self._prev_comments
      expression = self._parse_assignment()

      while self._match(TokenType.WHEN):
        when_comments = self._prev_comments
        this = self._parse_assignment()
        self._match(TokenType.THEN)
        then = self._parse_assignment()
        ifs.append(
            self.expression(
                exp.If, this=this, true=then, comments=when_comments))

      if self._match(TokenType.ELSE):
        default = self._parse_assignment()

      if not self._match(TokenType.END):
        # `ELSE interval END` is parsed as an interval whose operand is END;
        # turn it back into a column named `interval`.
        if isinstance(default,
                      exp.Interval) and default.this.sql().upper() == "END":
          default = exp.column("interval")
        else:
          self.raise_error("Expected END after CASE", self._prev)

      return self.expression(
          exp.Case,
          comments=comments,
          this=expression,
          ifs=ifs,
          default=default)

    def _parse_is(
        self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
      """Parse the right-hand side of `IS`, accepting any expression.

      Returns None (after rewinding the token position) when nothing
      parseable follows.
      """
      index = self._index - 1
      negate = self._match(TokenType.NOT)

      if self._match_text_seq("DISTINCT", "FROM"):
        # IS NOT DISTINCT FROM is null-safe equality.
        klass = exp.NullSafeEQ if negate else exp.NullSafeNEQ
        return self.expression(
            klass, this=this, expression=self._parse_bitwise())

      if self._match(TokenType.JSON):
        kind = self._match_texts(
            self.IS_JSON_PREDICATE_KIND) and self._prev.text.upper()

        if self._match_text_seq("WITH"):
          _with = True
        elif self._match_text_seq("WITHOUT"):
          _with = False
        else:
          _with = None

        unique = self._match(TokenType.UNIQUE)
        self._match_text_seq("KEYS")
        expression: t.Optional[exp.Expression] = self.expression(
            exp.JSON, **{
                "this": kind,
                "with": _with,
                "unique": unique
            })
      else:
        expression = self._parse_expression()
        if not expression:
          self._retreat(index)
          return None

      this = self.expression(exp.Is, this=this, expression=expression)
      return self.expression(exp.Not, this=this) if negate else this


def preprocess_macros(sql: str) -> Tuple[str, List[Tuple[str, str]]]:
  """Convert macro calls to placeholders for sqlglot parsing.

  Each outermost `name!(...)` call with balanced parentheses is replaced, at
  its own position, by a unique `__macro_<n>__` placeholder. Calls nested
  inside an already replaced call travel with the outer macro text. Calls
  whose parentheses never balance are left untouched.

  Args:
    sql: Input SQL string

  Returns:
    Tuple of (processed SQL, list of (placeholder, original macro text) pairs)
  """
  pieces = []
  macros = []
  # End offset (in `sql`) of the text consumed so far.
  cursor = 0

  for match in _MACRO_CALL_RE.finditer(sql):
    start = match.start()
    if start < cursor:
      # Nested inside a macro call that has already been replaced.
      continue

    # Scan forward to the parenthesis that closes the call.
    end = match.end()
    depth = 1
    while end < len(sql) and depth > 0:
      if sql[end] == '(':
        depth += 1
      elif sql[end] == ')':
        depth -= 1
      end += 1

    if depth != 0:
      continue

    placeholder = f'__macro_{len(macros)}__'
    macros.append((placeholder, sql[start:end]))
    # Splice by position rather than by text: a textual replace misses an
    # outer call whose inner call was already swapped for a placeholder.
    pieces.append(sql[cursor:start])
    pieces.append(placeholder)
    cursor = end

  pieces.append(sql[cursor:])
  return ''.join(pieces), macros


def postprocess_macros(sql: str, macros: List[Tuple[str, str]]) -> str:
  """Restore macro calls from their placeholders.

  Args:
    sql: Formatted SQL with placeholders
    macros: List of (placeholder, original macro text) pairs from
      preprocess_macros

  Returns:
    SQL with macros restored
  """
  result = sql
  for placeholder, macro_text in macros:
    result = result.replace(placeholder, macro_text)
  return result


def extract_comment_blocks(sql: str) -> Tuple[str, Dict[str, str]]:
  """Extract comment blocks from SQL and replace with placeholders.

  A comment block is defined as one or more comment lines (starting with --)
  that are surrounded by empty lines or start/end of file.

  Args:
    sql: Input SQL string

  Returns:
    A tuple containing:
    - Processed SQL with comment blocks replaced by placeholders
    - Dict mapping placeholders to their original comment blocks
  """
  # Split into chunks separated by empty lines
  chunks = []
  current = []
  for line in sql.splitlines():
    if line.strip():
      current.append(line)
    elif current:
      chunks.append(current)
      current = []
  if current:
    chunks.append(current)

  # Process each chunk
  blocks = {}
  result = []
  for i, chunk in enumerate(chunks):
    result.append('')
    # A chunk is a comment block if all lines start with --
    if all(line.strip().startswith('--') for line in chunk):
      # The placeholder is itself a comment so it survives parsing.
      placeholder = f'-- __COMMENT_BLOCK_{i}__'
      blocks[placeholder] = '\n'.join(chunk) + '\n'
      result.append(placeholder)
    else:
      result.extend(chunk)
    result.append('')

  return '\n'.join(result).strip(), blocks


def restore_comment_blocks(sql: str, blocks: Dict[str, str]) -> str:
  """Restore comment blocks from their placeholders.

  Args:
    sql: SQL string with placeholders
    blocks: Dict mapping placeholders to original comment blocks

  Returns:
    SQL string with comment blocks restored in their original positions
  """
  result = sql
  for placeholder, block in blocks.items():
    result = result.replace(placeholder, block)
  return result


def format_sql(file: Path,
               sql: str,
               indent_width: int = 2,
               verbose: bool = False) -> str:
  """Format SQL content with consistent style.

  Content containing the marker `-- sqlformat file off` is returned unchanged.

  Args:
    file: Path to the SQL file (for error reporting)
    sql: SQL content to format
    indent_width: Number of spaces for indentation
    verbose: Whether to print status messages

  Returns:
    Formatted SQL string

  Raises:
    Exception: If SQL parsing or formatting fails
  """
  if '-- sqlformat file off' in sql:
    if verbose:
      print(f"Ignoring {file}", file=sys.stderr)
    return sql

  # First extract comment blocks
  sql_with_placeholders, comment_blocks = extract_comment_blocks(sql)

  # Then process macros
  processed, macros = preprocess_macros(sql_with_placeholders)
  try:
    statements = []
    for ast in sqlglot.parse(sql=processed, dialect=Perfetto):
      statements.append(
          ast.sql(
              pretty=True,
              dialect=Perfetto,
              indent=indent_width,
              normalize=True,
              normalize_functions='lower',
          ) + ";\n\n")
    formatted = ''.join(statements)

    # Restore macros first, then comment blocks
    with_macros = postprocess_macros(formatted, macros)
    return restore_comment_blocks(with_macros, comment_blocks).rstrip() + '\n'
  except Exception as e:
    # Report which file failed, then propagate with the original traceback.
    print(f"Failed to format SQL: file {file}, {e}", file=sys.stderr)
    raise


def _iter_sql_files(path: Union[str, Path]) -> List[Path]:
  """Return the SQL files designated by `path`.

  A directory is searched recursively for `*.sql` files (returned in sorted
  order so runs are deterministic); any other path is returned as-is.
  """
  path = Path(path)
  if path.is_dir():
    return sorted(path.rglob('*.sql'))
  return [path]


def format_files_in_place(paths: List[Union[str, Path]],
                          indent_width: int = 2,
                          verbose: bool = False) -> None:
  """Format multiple SQL files in place.

  Files whose content is already formatted are not rewritten.

  Args:
    paths: List of file or directory paths to format
    indent_width: Number of spaces for indentation
    verbose: Whether to print status messages
  """
  for path in paths:
    for sql_file in _iter_sql_files(path):
      with open(sql_file, encoding='utf-8') as f:
        sql = f.read()
      formatted = format_sql(sql_file, sql, indent_width, verbose)
      if formatted != sql:
        with open(sql_file, 'w', encoding='utf-8') as f:
          f.write(formatted)
      if verbose:
        print(f"Formatted {sql_file}", file=sys.stderr)


def check_sql_formatting(paths: List[Union[str, Path]],
                         indent_width: int = 2,
                         verbose: bool = False) -> bool:
  """Check SQL files for formatting violations without making changes.

  Args:
    paths: List of file or directory paths to check
    indent_width: Number of spaces for indentation
    verbose: Whether to print status messages

  Returns:
    True if all files are properly formatted, False otherwise
  """
  all_formatted = True
  for path in paths:
    for sql_file in _iter_sql_files(path):
      with open(sql_file, encoding='utf-8') as f:
        sql = f.read()
      formatted = format_sql(sql_file, sql, indent_width, verbose)
      if formatted != sql:
        print(f"Would format {sql_file}", file=sys.stderr)
        all_formatted = False

  return all_formatted


def main() -> None:
  """Main entry point."""
  if not SQLGLOT_DIR.exists():
    print(
        f"{SQLGLOT_DIR} does not exist. Run tools/install-build-deps first.",
        file=sys.stderr)
    sys.exit(1)

  parser = argparse.ArgumentParser(
      description='Format SQL queries with consistent style')
  parser.add_argument(
      'paths',
      nargs='*',
      help='Paths to SQL files or directories containing SQL files')
  parser.add_argument(
      '--indent-width',
      type=int,
      default=2,
      help='Number of spaces for indentation (default: 2)')
  parser.add_argument(
      '--in-place', action='store_true', help='Format files in place')
  parser.add_argument(
      '--check-only',
      action='store_true',
      help='Check for formatting violations without making changes')
  parser.add_argument(
      '--verbose',
      action='store_true',
      help='Print status messages during execution')
  args = parser.parse_args()

  if args.check_only:
    properly_formatted = check_sql_formatting(args.paths, args.indent_width,
                                              args.verbose)
    sys.exit(0 if properly_formatted else 1)
  elif args.in_place:
    format_files_in_place(args.paths, args.indent_width, args.verbose)
  else:
    # Read from stdin if no files provided
    if not sys.stdin.isatty():
      sql_input = sys.stdin.read()
      formatted_sql = format_sql(Path("stdin"), sql_input, args.indent_width)
      print(formatted_sql)
    else:
      # Print help if no input provided
      parser.print_help()


if __name__ == '__main__':
  main()