# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module related to code analysis and generation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ctypes import util  # pylint: disable=unused-import
import glob  # pylint: disable=unused-import
import itertools
import os
import re
import sys  # pylint: disable=unused-import
# pylint: disable=unused-import
from typing import (
    Text,
    List,
    Optional,
    Set,
    Dict,
    Callable,
    IO,
    Generator as Gen,
    Tuple,
    Union,
    Sequence,
)
# pylint: enable=unused-import

# Use Python bindings to libclang that are installed on the system.
for p in glob.glob(
    # Default system path on Debian
    '/usr/lib/python*/dist-packages'
) + glob.glob(
    # Fedora and others
    '/usr/lib/python*/site-packages'
):
  sys.path.append(p)

# Imported only after sys.path has been extended above, so that the
# system-installed libclang bindings can be found.
from clang import cindex  # pylint: disable=g-import-not-at-top


# Options passed to libclang for every parsed input file.
_PARSE_OPTIONS = (
    cindex.TranslationUnit.PARSE_SKIP_FUNCTION_BODIES
    | cindex.TranslationUnit.PARSE_INCOMPLETE
    # for include directives
    | cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD
)

# Matches the `const` qualifier as a whole word only, so that identifiers which
# merely contain the substring (e.g. `constraint_t`) are left untouched.
_CONST_QUALIFIER_RE = re.compile(r'\bconst\b')


def _strip_const(spelling):
  # type: (Text) -> Text
  """Removes `const` qualifiers from a clang type spelling.

  Const pointers do not play well with the SAPI C++ API, so the generated code
  drops the qualifier from pointer argument types and from return types.

  Args:
    spelling: clang spelling of a type, e.g. 'const char *'

  Returns:
    the spelling with every whole-word 'const' removed, e.g. ' char *'
  """
  return _CONST_QUALIFIER_RE.sub('', spelling)


def get_header_guard(path):
  # type: (Text) -> Text
  """Generates header guard string from path.

  The output file will most likely be somewhere in genfiles; everything up to
  and including the first 'genfiles/' is stripped in that case. A trailing
  '.gen' suffix (used for the step before clang-format) is stripped as well.

  Args:
    path: path of the generated header

  Returns:
    upper-cased path with '.', '-' and '/' replaced by '_', followed by '_'

  Raises:
    ValueError: if path is empty
  """
  if not path:
    raise ValueError('Cannot prepare header guard from path: {}'.format(path))
  if 'genfiles/' in path:
    # Keep everything after the first occurrence, even if 'genfiles/' appears
    # again later in the path.
    path = path.split('genfiles/', 1)[1]
  if path.endswith('.gen'):
    # Strip only the suffix; '.gen' may also occur earlier in the path.
    path = path[:-len('.gen')]
  path = path.upper().replace('.', '_').replace('-', '_').replace('/', '_')
  return path + '_'


def _stringify_tokens(tokens, separator='\n'):
  # type: (Sequence[cindex.Token], Text) -> Text
  """Converts tokens to text respecting line position (disrespecting column).

  Tokens from the same source line are joined into one output line; output
  lines are joined with `separator`. Indentation is derived from braces.
  """
  previous = OutputLine(0, [])  # not used in output
  lines = []  # type: List[OutputLine]

  for _, group in itertools.groupby(tokens, lambda t: t.location.line):
    line = OutputLine(previous.next_tab, list(group))
    lines.append(line)
    previous = line

  return separator.join(str(line) for line in lines)


# Maps builtin clang type kinds to their SAPI variable wrapper types.
TYPE_MAPPING = {
    cindex.TypeKind.VOID: '::sapi::v::Void',
    cindex.TypeKind.CHAR_S: '::sapi::v::Char',
    cindex.TypeKind.CHAR_U: '::sapi::v::Char',
    cindex.TypeKind.INT: '::sapi::v::Int',
    cindex.TypeKind.UINT: '::sapi::v::UInt',
    cindex.TypeKind.LONG: '::sapi::v::Long',
    cindex.TypeKind.ULONG: '::sapi::v::ULong',
    cindex.TypeKind.UCHAR: '::sapi::v::UChar',
    cindex.TypeKind.USHORT: '::sapi::v::UShort',
    cindex.TypeKind.SHORT: '::sapi::v::Short',
    cindex.TypeKind.LONGLONG: '::sapi::v::LLong',
    cindex.TypeKind.ULONGLONG: '::sapi::v::ULLong',
    cindex.TypeKind.FLOAT: '::sapi::v::Reg<float>',
    cindex.TypeKind.DOUBLE: '::sapi::v::Reg<double>',
    cindex.TypeKind.LONGDOUBLE: '::sapi::v::Reg<long double>',
    cindex.TypeKind.SCHAR: '::sapi::v::SChar',
    cindex.TypeKind.BOOL: '::sapi::v::Bool',
}


class Type(object):
  """Class representing a type.

  Wraps cindex.Type of the argument/return value and provides helpers for the
  code generation.
  """

  def __init__(self, tu, clang_type):
    # type: (_TranslationUnit, cindex.Type) -> None
    self._clang_type = clang_type
    self._tu = tu

  # pylint: disable=protected-access
  def __eq__(self, other):
    # type: (object) -> bool
    """Types are equal when their declarations have the same USR."""
    if not isinstance(other, Type):
      return NotImplemented
    # Use get_usr() to deduplicate Type objects based on declaration
    decl = self._get_declaration()
    decl_o = other._get_declaration()

    return decl.get_usr() == decl_o.get_usr()

  def __ne__(self, other):
    # type: (object) -> bool
    result = self.__eq__(other)
    if result is NotImplemented:
      return result
    return not result

  def __lt__(self, other):
    # type: (Type) -> bool
    """Compares two Types belonging to the same TranslationUnit.

    This is being used to properly order types before emitting to generated
    file. To be more specific: structure definition that contains field that is
    a typedef should end up after that typedef definition. This is achieved by
    exploiting the order in which clang iterate over AST in translation unit.

    Args:
      other: other comparison type

    Returns:
      true if this Type occurs earlier in the AST than 'other'

    Raises:
      ValueError: if the types come from different translation units
    """
    self._validate_tu(other)
    return (self._tu.order[self._get_declaration().hash] <
            self._tu.order[other._get_declaration().hash])  # pylint: disable=protected-access

  def __gt__(self, other):
    # type: (Type) -> bool
    """Compares two Types belonging to the same TranslationUnit.

    This is being used to properly order types before emitting to generated
    file. To be more specific: structure definition that contains field that is
    a typedef should end up after that typedef definition. This is achieved by
    exploiting the order in which clang iterate over AST in translation unit.

    Args:
      other: other comparison type

    Returns:
      true if this Type occurs later in the AST than 'other'

    Raises:
      ValueError: if the types come from different translation units
    """
    self._validate_tu(other)
    return (self._tu.order[self._get_declaration().hash] >
            self._tu.order[other._get_declaration().hash])  # pylint: disable=protected-access

  def __hash__(self):
    # type: () -> int
    """Types with the same declaration should hash to the same value."""
    return hash(self._get_declaration().get_usr())

  def _validate_tu(self, other):
    # type: (Type) -> None
    if self._tu != other._tu:  # pylint: disable=protected-access
      raise ValueError('Cannot compare types from different translation units.')

  def is_void(self):
    # type: () -> bool
    return self._clang_type.kind == cindex.TypeKind.VOID

  def is_typedef(self):
    # type: () -> bool
    return self._clang_type.kind == cindex.TypeKind.TYPEDEF

  def is_elaborated(self):
    # type: () -> bool
    return self._clang_type.kind == cindex.TypeKind.ELABORATED

  # Hack: both class and struct types are indistinguishable except for
  # declaration cursor kind
  def is_sugared_record(self):  # class, struct, union
    # type: () -> bool
    return self._clang_type.get_declaration().kind in (
        cindex.CursorKind.STRUCT_DECL, cindex.CursorKind.UNION_DECL,
        cindex.CursorKind.CLASS_DECL)

  def is_struct(self):
    # type: () -> bool
    return (self._clang_type.get_declaration().kind ==
            cindex.CursorKind.STRUCT_DECL)

  def is_class(self):
    # type: () -> bool
    return (self._clang_type.get_declaration().kind ==
            cindex.CursorKind.CLASS_DECL)

  def is_union(self):
    # type: () -> bool
    return (self._clang_type.get_declaration().kind ==
            cindex.CursorKind.UNION_DECL)

  def is_function(self):
    # type: () -> bool
    return self._clang_type.kind == cindex.TypeKind.FUNCTIONPROTO

  def is_sugared_ptr(self):
    # type: () -> bool
    return self._clang_type.get_canonical().kind == cindex.TypeKind.POINTER

  def is_sugared_enum(self):
    # type: () -> bool
    return self._clang_type.get_canonical().kind == cindex.TypeKind.ENUM

  def is_const_array(self):
    # type: () -> bool
    return self._clang_type.kind == cindex.TypeKind.CONSTANTARRAY

  def is_simple_type(self):
    # type: () -> bool
    return self._clang_type.kind in TYPE_MAPPING

  def get_pointee(self):
    # type: () -> Type
    return Type(self._tu, self._clang_type.get_pointee())

  def _get_declaration(self):
    # type: () -> cindex.Cursor
    """Returns the declaration cursor, looking through pointers if needed."""
    decl = self._clang_type.get_declaration()
    if decl.kind == cindex.CursorKind.NO_DECL_FOUND and self.is_sugared_ptr():
      decl = self.get_pointee()._get_declaration()  # pylint: disable=protected-access

    return decl

  def get_related_types(self, result=None, skip_self=False):
    # type: (Optional[Set[Type]], bool) -> Set[Type]
    """Returns all types related to this one eg. typedefs, nested structs.

    Args:
      result: set to accumulate into (mutated in place); a new set if None
      skip_self: when True, this type itself is not added (its children are)

    Returns:
      the accumulated set of related types
    """
    if result is None:
      result = set()

    # Base case.
    if self in result or self.is_simple_type() or self.is_class():
      return result

    # Sugar types.
    if self.is_typedef():
      return self._get_related_types_of_typedef(result)

    if self.is_elaborated():
      return Type(self._tu,
                  self._clang_type.get_named_type()).get_related_types(
                      result, skip_self)

    # Composite types.
    if self.is_const_array():
      t = Type(self._tu, self._clang_type.get_array_element_type())
      return t.get_related_types(result)

    if self._clang_type.kind in (cindex.TypeKind.POINTER,
                                 cindex.TypeKind.MEMBERPOINTER,
                                 cindex.TypeKind.LVALUEREFERENCE,
                                 cindex.TypeKind.RVALUEREFERENCE):
      return self.get_pointee().get_related_types(result, skip_self)

    # union + struct, class should be filtered out
    if self.is_struct() or self.is_union():
      return self._get_related_types_of_record(result, skip_self)

    if self.is_function():
      return self._get_related_types_of_function(result)

    if self.is_sugared_enum():
      if not skip_self:
        result.add(self)
        self._tu.search_for_macro_name(self._get_declaration())
      return result

    # Ignore all cindex.TypeKind.UNEXPOSED AST nodes
    # TODO(b/256934562): Remove the disable once the pytype bug is fixed.
    return result  # pytype: disable=bad-return-type

  def _get_related_types_of_typedef(self, result):
    # type: (Set[Type]) -> Set[Type]
    """Returns all intermediate types related to the typedef."""
    result.add(self)
    decl = self._clang_type.get_declaration()
    self._tu.search_for_macro_name(decl)

    t = Type(self._tu, decl.underlying_typedef_type)
    if t.is_sugared_ptr():
      t = t.get_pointee()

    if not t.is_simple_type():
      skip_child = self.contains_declaration(t)
      if t.is_sugared_record() and skip_child:
        # if child declaration is contained in parent, we don't have to emit it
        self._tu.types_to_skip.add(t)
      result.update(t.get_related_types(result, skip_child))

    return result

  def _get_related_types_of_record(self, result, skip_self=False):
    # type: (Set[Type], bool) -> Set[Type]
    """Returns all types related to the structure."""
    # skip unnamed structures eg. typedef struct {...} x;
    # struct {...} will be rendered as part of typedef rendering
    decl = self._get_declaration()
    if not decl.is_anonymous() and not skip_self:
      self._tu.search_for_macro_name(decl)
      result.add(self)

    for f in self._clang_type.get_fields():
      self._tu.search_for_macro_name(f)
      result.update(Type(self._tu, f.type).get_related_types(result))

    return result

  def _get_related_types_of_function(self, result):
    # type: (Set[Type]) -> Set[Type]
    """Returns all types related to the function."""
    for arg in self._clang_type.argument_types():
      result.update(Type(self._tu, arg).get_related_types(result))
    related = Type(self._tu,
                   self._clang_type.get_result()).get_related_types(result)
    result.update(related)

    return result

  def contains_declaration(self, other):
    # type: (Type) -> bool
    """Checks if the source extent of this type contains the other type's.

    Returns False when the other declaration has no associated file.
    """
    self_extent = self._get_declaration().extent
    other_extent = other._get_declaration().extent  # pylint: disable=protected-access

    if other_extent.start.file is None:
      return False
    return (other_extent.start in self_extent and
            other_extent.end in self_extent)

  def stringify(self):
    # type: () -> Text
    """Returns string representation of the Type (comments are dropped)."""
    # (szwl): as simple as possible, keeps macros in separate lines not to
    # break things; this will go through clang format nevertheless
    tokens = [
        x for x in self._get_declaration().get_tokens()
        if x.kind is not cindex.TokenKind.COMMENT
    ]

    return _stringify_tokens(tokens)


class OutputLine(object):
  """Helper class for Type printing: one output line built from tokens."""

  def __init__(self, tab, tokens):
    # type: (int, List[cindex.Token]) -> None
    self.tokens = tokens
    self.spellings = []  # type: List[Text]
    # True for preprocessor lines (they are emitted without indentation).
    self.define = False
    self.tab = tab
    # Indentation level for the line that follows this one.
    self.next_tab = tab
    list(map(self._process_token, self.tokens))

  def _process_token(self, t):
    # type: (cindex.Token) -> None
    """Processes a token, setting up internal states rel. to indentation."""
    if t.spelling == '#':
      self.define = True
    elif t.spelling == '{':
      self.next_tab += 1
    elif t.spelling == '}':
      self.tab -= 1
      self.next_tab -= 1

    is_bracket = t.spelling == '('
    is_macro = len(self.spellings) == 1 and self.spellings[0] == '#'
    if self.spellings and not is_bracket and not is_macro:
      self.spellings.append(' ')
    self.spellings.append(t.spelling)

  def __str__(self):
    # type: () -> Text
    tabs = ('\t' * self.tab) if not self.define else ''
    return tabs + ''.join(t for t in self.spellings)


class ArgumentType(Type):
  """Class representing function argument type.

  Object fields are being used by the code template:
    pos: argument position
    name: argument name ('a<pos>' when the declaration has none)
    type: string representation of the type
    mapped_type: SAPI equivalent of the type
    wrapped: wraps type in SAPI object constructor
    call_argument: type (or it's sapi wrapper) used in function call
  str() of the object yields the type as a function argument.
  """

  def __init__(self, function, pos, arg_type, name=None):
    # type: (Function, int, cindex.Type, Optional[Text]) -> None
    super(ArgumentType, self).__init__(function.translation_unit(), arg_type)
    self._function = function

    self.pos = pos
    self.name = name or 'a{}'.format(pos)
    self.type = arg_type.spelling

    # Pointers are passed as-is; other values through their local SAPI wrapper.
    template = '{}' if self.is_sugared_ptr() else '&{}_'
    self.call_argument = template.format(self.name)

  def __str__(self):
    # type: () -> Text
    """Returns function argument prepared from the type."""
    if self.is_sugared_ptr():
      return '::sapi::v::Ptr* {}'.format(self.name)

    return '{} {}'.format(self._clang_type.spelling, self.name)

  @property
  def wrapped(self):
    # type: () -> Text
    return '{} {name}_(({name}))'.format(self.mapped_type, name=self.name)

  @property
  def mapped_type(self):
    # type: () -> Text
    """Maps the type to its SAPI equivalent.

    Raises:
      ValueError: for record/elaborated types, which are not supported
      KeyError: for builtin types missing from TYPE_MAPPING
    """
    if self.is_sugared_ptr():
      # TODO(szwl): const ptrs do not play well with SAPI C++ API...
      spelling = _strip_const(self._clang_type.spelling)
      return '::sapi::v::Reg<{}>'.format(spelling)

    type_ = self._clang_type

    if type_.kind == cindex.TypeKind.TYPEDEF:
      type_ = self._clang_type.get_canonical()
    if type_.kind == cindex.TypeKind.ELABORATED:
      type_ = type_.get_canonical()
    if type_.kind == cindex.TypeKind.ENUM:
      return '::sapi::v::IntBase<{}>'.format(self._clang_type.spelling)
    if type_.kind in [
        cindex.TypeKind.CONSTANTARRAY, cindex.TypeKind.INCOMPLETEARRAY
    ]:
      return '::sapi::v::Reg<{}>'.format(self._clang_type.spelling)

    if type_.kind == cindex.TypeKind.LVALUEREFERENCE:
      return 'LVALUEREFERENCE::NOT_SUPPORTED'

    if type_.kind == cindex.TypeKind.RVALUEREFERENCE:
      return 'RVALUEREFERENCE::NOT_SUPPORTED'

    if type_.kind in [cindex.TypeKind.RECORD, cindex.TypeKind.ELABORATED]:
      raise ValueError('Elaborate type (eg. struct) in mapped_type is not '
                       'supported: function {}, arg {}, type {}, location {}'
                       ''.format(self._function.name, self.pos,
                                 self._clang_type.spelling,
                                 self._function.cursor.location))

    if type_.kind not in TYPE_MAPPING:
      raise KeyError('Key {} does not exist in TYPE_MAPPING.'
                     ' function {}, arg {}, type {}, location {}'
                     ''.format(type_.kind, self._function.name, self.pos,
                               self._clang_type.spelling,
                               self._function.cursor.location))

    return TYPE_MAPPING[type_.kind]


class ReturnType(ArgumentType):
  """Class representing function return type.

  str() of the object yields absl::StatusOr<T>, where T is the original return
  type (with const qualifiers removed), or absl::Status for functions returning
  void.
  """

  def __init__(self, function, arg_type):
    # type: (Function, cindex.Type) -> None
    super(ReturnType, self).__init__(function, 0, arg_type, None)

  def __str__(self):
    # type: () -> Text
    """Returns function return type prepared from the type."""
    if self.is_void():
      return 'absl::Status'
    # TODO(szwl): const ptrs do not play well with SAPI C++ API...
    spelling = _strip_const(self._clang_type.spelling)
    return 'absl::StatusOr<{}>'.format(spelling)


class Function(object):
  """Class representing SAPI-wrapped function used by the template.

  Wraps Clang cursor object of kind FUNCTION_DECL and provides helpers to
  aid code generation.
  """

  def __init__(self, tu, cursor):
    # type: (_TranslationUnit, cindex.Cursor) -> None
    self._tu = tu
    self.cursor = cursor  # type: cindex.Cursor
    self.name = cursor.spelling  # type: Text
    self.result = ReturnType(self, cursor.result_type)
    self.original_definition = '{} {}'.format(
        cursor.result_type.spelling, self.cursor.displayname)  # type: Text

    types = self.cursor.get_arguments()
    self.argument_types = [
        ArgumentType(self, i, t.type, t.spelling) for i, t in enumerate(types)
    ]

  def translation_unit(self):
    # type: () -> _TranslationUnit
    return self._tu

  def arguments(self):
    # type: () -> List[ArgumentType]
    return self.argument_types

  def call_arguments(self):
    # type: () -> List[Text]
    return [a.call_argument for a in self.argument_types]

  def get_absolute_path(self):
    # type: () -> Text
    return self.cursor.location.file.name

  def get_include_path(self, prefix):
    # type: (Optional[Text]) -> Text
    """Creates a proper include path.

    Args:
      prefix: include prefix; a trailing '/' is added if missing

    Returns:
      the absolute path when prefix is empty; otherwise prefix followed by the
      part of the path after the prefix's last occurrence, or by the file's
      basename when the prefix does not occur in the path
    """
    # TODO(szwl): sanity checks
    # TODO(szwl): prefix 'utils/' and the path is '.../fileutils/...' case
    if prefix and not prefix.endswith('/'):
      prefix += '/'

    if not prefix:
      return self.get_absolute_path()
    elif prefix in self.get_absolute_path():
      return prefix + self.get_absolute_path().split(prefix)[-1]
    return prefix + self.get_absolute_path().split('/')[-1]

  def get_related_types(self, processed=None):
    # type: (Optional[Set[Type]]) -> Set[Type]
    """Returns types related to the return type and all argument types."""
    result = self.result.get_related_types(processed)
    for a in self.argument_types:
      result.update(a.get_related_types(processed))

    return result

  def is_mangled(self):
    # type: () -> bool
    return self.cursor.mangled_name != self.cursor.spelling

  def __hash__(self):
    # type: () -> int
    return hash(self.cursor.get_usr())

  def __eq__(self, other):
    # type: (object) -> bool
    if not isinstance(other, Function):
      return NotImplemented
    return self.cursor.mangled_name == other.cursor.mangled_name


class _TranslationUnit(object):
  """Class wrapping clang's _TranslationUnit. Provides extra utilities."""

  def __init__(self, path, tu, limit_scan_depth=False, func_names=None):
    # type: (Text, cindex.TranslationUnit, bool, Optional[List[Text]]) -> None
    """Initializes the translation unit.

    Args:
      path: path to source of the translation unit
      tu: cindex translation unit
      limit_scan_depth: whether scan should be limited to single file
      func_names: list of function names to take into consideration, empty means
        all functions.
    """
    self.path = path
    self.limit_scan_depth = limit_scan_depth
    self._tu = tu
    self._processed = False
    # Type of a forward-declared struct -> cursor of its forward declaration.
    self.forward_decls = dict()  # type: Dict[Type, cindex.Cursor]
    self.functions = set()  # type: Set[Function]
    # Cursor hash -> position of the cursor in the preorder AST walk.
    self.order = dict()  # type: Dict[int, int]
    # Macro name -> MACRO_DEFINITION cursor.
    self.defines = {}  # type: Dict[Text, cindex.Cursor]
    # Names of macros referenced by types that will be emitted.
    self.required_defines = set()  # type: Set[Text]
    # Types rendered as part of another type; they must not be emitted twice.
    self.types_to_skip = set()  # type: Set[Type]
    self.func_names = func_names or []

  def _process(self):
    # type: () -> None
    """Walks the cursor tree once and caches some for future use."""
    if not self._processed:
      # self.includes[self._tu.spelling] = (0, self._tu.cursor)
      self._processed = True
      # TODO(szwl): duplicates?
      # TODO(szwl): for d in translation_unit.diagnostics:, handle that

      for i, cursor in enumerate(self._walk_preorder()):
        # Workaround for issue#32
        # ignore all the cursors with kinds not implemented in python bindings
        try:
          cursor.kind
        except ValueError:
          continue
        # naive way to order types: they should be ordered when walking the tree
        if cursor.kind.is_declaration():
          self.order[cursor.hash] = i

        if (cursor.kind == cindex.CursorKind.MACRO_DEFINITION and
            cursor.location.file):
          self.order[cursor.hash] = i
          self.defines[cursor.spelling] = cursor

        # most likely a forward decl of struct
        if (cursor.kind == cindex.CursorKind.STRUCT_DECL and
            not cursor.is_definition()):
          self.forward_decls[Type(self, cursor.type)] = cursor
        if (cursor.kind == cindex.CursorKind.FUNCTION_DECL and
            cursor.linkage != cindex.LinkageKind.INTERNAL):
          # Skip non-interesting functions
          if self.func_names and cursor.spelling not in self.func_names:
            continue
          if self.limit_scan_depth:
            # The location may have no file (file is None); such functions
            # cannot belong to self.path.
            location_file = cursor.location.file
            if location_file is not None and location_file.name == self.path:
              self.functions.add(Function(self, cursor))
          else:
            self.functions.add(Function(self, cursor))

  def get_functions(self):
    # type: () -> Set[Function]
    self._process()
    return self.functions

  def _walk_preorder(self):
    # type: () -> Gen
    for c in self._tu.cursor.walk_preorder():
      yield c

  def search_for_macro_name(self, cursor):
    # type: (cindex.Cursor) -> None
    """Records macros referenced by the cursor's tokens in required_defines.

    Macro definitions found this way are searched recursively, so macros used
    by other macros are recorded as well.
    """
    tokens = list(t.spelling for t in cursor.get_tokens())
    try:
      for token in tokens:
        if token in self.defines and token not in self.required_defines:
          self.required_defines.add(token)
          self.search_for_macro_name(self.defines[token])
    except ValueError:
      return


class Analyzer(object):
  """Class responsible for analysis."""

  # pylint: disable=line-too-long
  @staticmethod
  def process_files(
      input_paths, compile_flags, limit_scan_depth=False, func_names=None
  ):
    # type: (List[Text], List[Text], bool, Optional[List[Text]]) -> List[_TranslationUnit]
    """Processes files with libclang and returns TranslationUnit objects."""

    tus = []
    for path in input_paths:
      tu = Analyzer._analyze_file_for_tu(
          path,
          compile_flags=compile_flags,
          limit_scan_depth=limit_scan_depth,
          func_names=func_names,
      )
      tus.append(tu)
    return tus

  # pylint: disable=line-too-long
  @staticmethod
  def _analyze_file_for_tu(
      path,
      compile_flags=None,
      test_file_existence=True,
      unsaved_files=None,
      limit_scan_depth=False,
      func_names=None,
  ):
    # type: (Text, Optional[List[Text]], bool, Optional[Tuple[Text, Union[Text, IO[Text]]]], bool, Optional[List[Text]]) -> _TranslationUnit
    """Returns Analysis object for given path.

    Raises:
      IOError: if test_file_existence is set and path is not a file
    """
    compile_flags = compile_flags or []
    if test_file_existence and not os.path.isfile(path):
      raise IOError('Path {} does not exist.'.format(path))

    index = cindex.Index.create()  # type: cindex.Index
    # TODO(szwl): hack until I figure out how python swig does that.
    # Headers will be parsed as C++. C libs usually have
    # '#ifdef __cplusplus extern "C"' for compatibility with c++
    lang = '-xc++' if not path.endswith('.c') else '-xc'
    args = [lang]
    args += compile_flags
    args.append('-I.')
    return _TranslationUnit(
        path,
        index.parse(
            path, args=args, unsaved_files=unsaved_files, options=_PARSE_OPTIONS
        ),
        limit_scan_depth=limit_scan_depth,
        func_names=func_names,
    )


class Generator(object):
  """Class responsible for code generation."""

  AUTO_GENERATED = ('// AUTO-GENERATED by the Sandboxed API generator.\n'
                    '// Edits will be discarded when regenerating this file.\n')

  GUARD_START = ('#ifndef {0}\n' '#define {0}')
  GUARD_END = '#endif  // {}'
  EMBED_INCLUDE = '#include "{}"'
  EMBED_CLASS = '''
class {0}Sandbox : public ::sapi::Sandbox {{
 public:
  {0}Sandbox()
      : ::sapi::Sandbox([]() {{
          static auto* fork_client_context =
              new ::sapi::ForkClientContext({1}_embed_create());
          return fork_client_context;
        }}()) {{}}
}};'''

  def __init__(self, translation_units):
    # type: (List[_TranslationUnit]) -> None
    """Initializes the generator.

    Args:
      translation_units: list of _TranslationUnit objects for the analyzed
        files (e.g. as returned by Analyzer.process_files)
    """
    self.translation_units = translation_units
    self.functions = None

  def generate(
      self,
      name,
      namespace=None,
      output_file=None,
      embed_dir=None,
      embed_name=None,
  ):
    # pylint: disable=line-too-long
    # type: (Text, Optional[Text], Optional[Text], Optional[Text], Optional[Text]) -> Text
    """Generates structures, functions and typedefs.

    Args:
      name: name of the class that will contain generated interface
      namespace: namespace of the interface ('::'-separated for nesting)
      output_file: path to the output file, used to generate header guards;
        defaults to None, in which case no header guard is emitted
      embed_dir: path to directory with embed includes
      embed_name: name of the embed object

    Returns:
      generated interface as a string
    """
    related_types = self._get_related_types()
    forward_decls = self._get_forward_decls(related_types)
    functions = self._get_functions()
    related_types = [(t.stringify() + ';') for t in related_types]
    defines = self._get_defines()

    api = {
        'name': name,
        'functions': functions,
        'related_types': defines + forward_decls + related_types,
        'namespaces': namespace.split('::') if namespace else [],
        'embed_dir': embed_dir,
        'embed_name': embed_name,
        'output_file': output_file
    }
    return self.format_template(**api)

  def _get_functions(self):
    # type: () -> List[Function]
    """Gets Function objects that will be used to generate interface.

    The result is computed once, deduplicated, sorted by name and cached.
    """
    if self.functions is not None:
      return self.functions
    self.functions = []
    # TODO(szwl): for d in translation_unit.diagnostics:, handle that
    for translation_unit in self.translation_units:
      self.functions += translation_unit.get_functions()
    # allow only nonmangled functions - C++ overloads are not handled in
    # code generation
    self.functions = [f for f in self.functions if not f.is_mangled()]

    # remove duplicates
    self.functions = list(set(self.functions))
    self.functions.sort(key=lambda x: x.name)
    return self.functions

  def _get_related_types(self):
    # type: () -> List[Type]
    """Gets type definitions related to chosen functions.

    Types related to one function will land in the same translation unit,
    we gather the types, sort it and put as a sublist in types list.
    This is necessary as we can't compare types from two different translation
    units.

    Returns:
      list of types in correct (ready to render) order
    """
    processed = set()
    types = []
    types_to_skip = set()

    for f in self._get_functions():
      fn_related_types = f.get_related_types()
      types += sorted(r for r in fn_related_types if r not in processed)
      processed.update(fn_related_types)
      types_to_skip.update(f.translation_unit().types_to_skip)

    return [t for t in types if t not in types_to_skip]

  def _get_defines(self):
    # type: () -> List[Text]
    """Gets #define directives that appeared during TranslationUnit processing.

    Returns:
      list of #define string representations, in AST order per translation
      unit
    """

    def make_sort_condition(translation_unit):
      return lambda cursor: translation_unit.order[cursor.hash]

    result = []
    for tu in self.translation_units:
      tmp_result = []
      sort_condition = make_sort_condition(tu)
      for name in tu.required_defines:
        if name in tu.defines:
          define = tu.defines[name]
          tmp_result.append(define)
      for define in sorted(tmp_result, key=sort_condition):
        result.append('#define ' +
                      _stringify_tokens(define.get_tokens(), separator=' \\\n'))
    return result

  def _get_forward_decls(self, types):
    # type: (List[Type]) -> List[Text]
    """Gets forward declarations of related types, if present."""
    forward_decls = dict()
    result = []
    done = set()
    for tu in self.translation_units:
      forward_decls.update(tu.forward_decls)

    for t in types:
      if t in forward_decls and t not in done:
        result.append(_stringify_tokens(forward_decls[t].get_tokens()) + ';')
        done.add(t)

    return result

  def _format_function(self, f):
    # type: (Function) -> Text
    """Renders one function of the Api.

    Args:
      f: function object with information necessary to emit full function body

    Returns:
      filled function template
    """
    result = []
    result.append('  // {}'.format(f.original_definition))

    arguments = ', '.join(str(a) for a in f.arguments())
    result.append('  {} {}({}) {{'.format(f.result, f.name, arguments))
    result.append('    {} ret;'.format(f.result.mapped_type))

    # Non-pointer arguments are wrapped in local SAPI variables.
    for a in f.argument_types:
      if not a.is_sugared_ptr():
        result.append('    {}'.format(a.wrapped + ';'))

    call_arguments = f.call_arguments()
    if call_arguments:  # fake empty space to add ',' before first argument
      call_arguments.insert(0, '')
    result.append('')
    # For OSS, the macro below will be replaced.
    result.append('    SAPI_RETURN_IF_ERROR(sandbox_->Call("{}", &ret{}));'
                  ''.format(f.name, ', '.join(call_arguments)))

    return_status = 'return absl::OkStatus();'
    if f.result and not f.result.is_void():
      if f.result.is_sugared_enum():
        return_status = ('return static_cast<{}>'
                         '(ret.GetValue());').format(f.result.type)
      else:
        return_status = 'return ret.GetValue();'
    result.append('    {}'.format(return_status))
    result.append('  }')

    return '\n'.join(result)

  def format_template(self, name, functions, related_types, namespaces,
                      embed_dir, embed_name, output_file):
    # pylint: disable=line-too-long
    # type: (Text, List[Function], List[Text], List[Text], Optional[Text], Optional[Text], Optional[Text]) -> Text
    # pylint: enable=line-too-long
    """Formats arguments into proper interface header file.

    Args:
      name: name of the Api - 'Test' will yield TestApi object
      functions: list of functions to generate
      related_types: types used in the above functions
      namespaces: list of namespaces to wrap the Api class with
      embed_dir: directory where the embedded library lives
      embed_name: name of embedded library
      output_file: interface output path - used in header guard generation

    Returns:
      generated header file text
    """
    result = [Generator.AUTO_GENERATED]

    header_guard = get_header_guard(output_file) if output_file else ''
    if header_guard:
      result.append(Generator.GUARD_START.format(header_guard))

    # Copybara transform results in the paths below.
    result.append('#include "absl/status/status.h"')
    result.append('#include "absl/status/statusor.h"')
    result.append('#include "sandboxed_api/sandbox.h"')
    result.append('#include "sandboxed_api/util/status_macros.h"')
    result.append('#include "sandboxed_api/vars.h"')

    if embed_name:
      embed_dir = embed_dir or ''
      result.append(
          Generator.EMBED_INCLUDE.format(
              os.path.join(embed_dir, embed_name) + '_embed.h'))

    if namespaces:
      result.append('')
      for n in namespaces:
        result.append('namespace {} {{'.format(n))

    if related_types:
      result.append('')
      for t in related_types:
        result.append(t)

    result.append('')

    if embed_name:
      result.append(
          Generator.EMBED_CLASS.format(name, embed_name.replace('-', '_')))

    result.append('class {}Api {{'.format(name))
    result.append(' public:')
    result.append('  explicit {}Api(::sapi::Sandbox* sandbox)'
                  ' : sandbox_(sandbox) {{}}'.format(name))
    result.append('  // Deprecated')
    result.append('  ::sapi::Sandbox* GetSandbox() const { return sandbox(); }')
    result.append('  ::sapi::Sandbox* sandbox() const { return sandbox_; }')

    for f in functions:
      result.append('')
      result.append(self._format_function(f))

    result.append('')
    result.append(' private:')
    result.append('  ::sapi::Sandbox* sandbox_;')
    result.append('};')
    result.append('')

    if namespaces:
      for n in reversed(namespaces):
        result.append('}}  // namespace {}'.format(n))

    if header_guard:
      result.append(Generator.GUARD_END.format(header_guard))

    result.append('')

    return '\n'.join(result)