• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Module related to code analysis and generation."""
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19from ctypes import util  # pylint: disable=unused-import
20import glob  # pylint: disable=unused-import
21import itertools
22import os
23import sys  # pylint: disable=unused-import
24# pylint: disable=unused-import
25from typing import (
26    Text,
27    List,
28    Optional,
29    Set,
30    Dict,
31    Callable,
32    IO,
33    Generator as Gen,
34    Tuple,
35    Union,
36    Sequence,
37)
38# pylint: enable=unused-import
39
40# Use Python bindings to libclang that are in installed on the system.
41for p in glob.glob(
42    # Default system path on Debian
43    '/usr/lib/python*/dist-packages'
44) + glob.glob(
45    # Fedora and others
46    '/usr/lib/python*/site-packages'
47):
48  sys.path.append(p)
49
50from clang import cindex
51
52
53_PARSE_OPTIONS = (
54    cindex.TranslationUnit.PARSE_SKIP_FUNCTION_BODIES
55    | cindex.TranslationUnit.PARSE_INCOMPLETE |
56    # for include directives
57    cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD
58)
59
60
61def get_header_guard(path):
62  # type: (Text) -> Text
63  """Generates header guard string from path."""
64  # the output file will be most likely somewhere in genfiles, strip the
65  # prefix in that case, also strip .gen if this is a step before clang-format
66  if not path:
67    raise ValueError('Cannot prepare header guard from path: {}'.format(path))
68  if 'genfiles/' in path:
69    path = path.split('genfiles/')[1]
70  if path.endswith('.gen'):
71    path = path.split('.gen')[0]
72  path = path.upper().replace('.', '_').replace('-', '_').replace('/', '_')
73  return path + '_'
74
75
76def _stringify_tokens(tokens, separator='\n'):
77  # type: (Sequence[cindex.Token], Text) -> Text
78  """Converts tokens to text respecting line position (disrespecting column)."""
79  previous = OutputLine(0, [])  # not used in output
80  lines = []  # type: List[OutputLine]
81
82  for _, group in itertools.groupby(tokens, lambda t: t.location.line):
83    group_list = list(group)
84    line = OutputLine(previous.next_tab, group_list)
85
86    lines.append(line)
87    previous = line
88
89  return separator.join(str(l) for l in lines)
90
91
92TYPE_MAPPING = {
93    cindex.TypeKind.VOID: '::sapi::v::Void',
94    cindex.TypeKind.CHAR_S: '::sapi::v::Char',
95    cindex.TypeKind.CHAR_U: '::sapi::v::Char',
96    cindex.TypeKind.INT: '::sapi::v::Int',
97    cindex.TypeKind.UINT: '::sapi::v::UInt',
98    cindex.TypeKind.LONG: '::sapi::v::Long',
99    cindex.TypeKind.ULONG: '::sapi::v::ULong',
100    cindex.TypeKind.UCHAR: '::sapi::v::UChar',
101    cindex.TypeKind.USHORT: '::sapi::v::UShort',
102    cindex.TypeKind.SHORT: '::sapi::v::Short',
103    cindex.TypeKind.LONGLONG: '::sapi::v::LLong',
104    cindex.TypeKind.ULONGLONG: '::sapi::v::ULLong',
105    cindex.TypeKind.FLOAT: '::sapi::v::Reg<float>',
106    cindex.TypeKind.DOUBLE: '::sapi::v::Reg<double>',
107    cindex.TypeKind.LONGDOUBLE: '::sapi::v::Reg<long double>',
108    cindex.TypeKind.SCHAR: '::sapi::v::SChar',
109    cindex.TypeKind.BOOL: '::sapi::v::Bool',
110}
111
112
113class Type(object):
114  """Class representing a type.
115
116  Wraps cindex.Type of the argument/return value and provides helpers for the
117  code generation.
118  """
119
120  def __init__(self, tu, clang_type):
121    # type: (_TranslationUnit, cindex.Type) -> None
122    self._clang_type = clang_type
123    self._tu = tu
124
125  # pylint: disable=protected-access
126  def __eq__(self, other):
127    # type: (Type) -> bool
128    # Use get_usr() to deduplicate Type objects based on declaration
129    decl = self._get_declaration()
130    decl_o = other._get_declaration()
131
132    return decl.get_usr() == decl_o.get_usr()
133
134  def __ne__(self, other):
135    # type: (Type) -> bool
136    return not self.__eq__(other)
137
138  def __lt__(self, other):
139    # type: (Type) -> bool
140    """Compares two Types belonging to the same TranslationUnit.
141
142    This is being used to properly order types before emitting to generated
143    file. To be more specific: structure definition that contains field that is
144    a typedef should end up after that typedef definition. This is achieved by
145    exploiting the order in which clang iterate over AST in translation unit.
146
147    Args:
148      other: other comparison type
149
150    Returns:
151      true if this Type occurs earlier in the AST than 'other'
152    """
153    self._validate_tu(other)
154    return (self._tu.order[self._get_declaration().hash] <
155            self._tu.order[other._get_declaration().hash])  # pylint: disable=protected-access
156
157  def __gt__(self, other):
158    # type: (Type) -> bool
159    """Compares two Types belonging to the same TranslationUnit.
160
161    This is being used to properly order types before emitting to generated
162    file. To be more specific: structure definition that contains field that is
163    a typedef should end up after that typedef definition. This is achieved by
164    exploiting the order in which clang iterate over AST in translation unit.
165
166    Args:
167      other: other comparison type
168
169    Returns:
170      true if this Type occurs later in the AST than 'other'
171    """
172    self._validate_tu(other)
173    return (self._tu.order[self._get_declaration().hash] >
174            self._tu.order[other._get_declaration().hash])  # pylint: disable=protected-access
175
176  def __hash__(self):
177    """Types with the same declaration should hash to the same value."""
178    return hash(self._get_declaration().get_usr())
179
180  def _validate_tu(self, other):
181    # type: (Type) -> None
182    if self._tu != other._tu:  # pylint: disable=protected-access
183      raise ValueError('Cannot compare types from different translation units.')
184
185  def is_void(self):
186    # type: () -> bool
187    return self._clang_type.kind == cindex.TypeKind.VOID
188
189  def is_typedef(self):
190    # type: () -> bool
191    return self._clang_type.kind == cindex.TypeKind.TYPEDEF
192
193  def is_elaborated(self):
194    # type: () -> bool
195    return self._clang_type.kind == cindex.TypeKind.ELABORATED
196
197  # Hack: both class and struct types are indistinguishable except for
198  # declaration cursor kind
199  def is_sugared_record(self):  # class, struct, union
200    # type: () -> bool
201    return self._clang_type.get_declaration().kind in (
202        cindex.CursorKind.STRUCT_DECL, cindex.CursorKind.UNION_DECL,
203        cindex.CursorKind.CLASS_DECL)
204
205  def is_struct(self):
206    # type: () -> bool
207    return (self._clang_type.get_declaration().kind ==
208            cindex.CursorKind.STRUCT_DECL)
209
210  def is_class(self):
211    # type: () -> bool
212    return (self._clang_type.get_declaration().kind ==
213            cindex.CursorKind.CLASS_DECL)
214
215  def is_union(self):
216    # type: () -> bool
217    return (self._clang_type.get_declaration().kind ==
218            cindex.CursorKind.UNION_DECL)
219
220  def is_function(self):
221    # type: () -> bool
222    return self._clang_type.kind == cindex.TypeKind.FUNCTIONPROTO
223
224  def is_sugared_ptr(self):
225    # type: () -> bool
226    return self._clang_type.get_canonical().kind == cindex.TypeKind.POINTER
227
228  def is_sugared_enum(self):
229    # type: () -> bool
230    return self._clang_type.get_canonical().kind == cindex.TypeKind.ENUM
231
232  def is_const_array(self):
233    # type: () -> bool
234    return self._clang_type.kind == cindex.TypeKind.CONSTANTARRAY
235
236  def is_simple_type(self):
237    # type: () -> bool
238    return self._clang_type.kind in TYPE_MAPPING
239
240  def get_pointee(self):
241    # type: () -> Type
242    return Type(self._tu, self._clang_type.get_pointee())
243
244  def _get_declaration(self):
245    # type: () -> cindex.Cursor
246    decl = self._clang_type.get_declaration()
247    if decl.kind == cindex.CursorKind.NO_DECL_FOUND and self.is_sugared_ptr():
248      decl = self.get_pointee()._get_declaration()  # pylint: disable=protected-access
249
250    return decl
251
252  def get_related_types(self, result=None, skip_self=False):
253    # type: (Optional[Set[Type]], bool) -> Set[Type]
254    """Returns all types related to this one eg. typedefs, nested structs."""
255    if result is None:
256      result = set()
257
258    # Base case.
259    if self in result or self.is_simple_type() or self.is_class():
260      return result
261
262    # Sugar types.
263    if self.is_typedef():
264      return self._get_related_types_of_typedef(result)
265
266    if self.is_elaborated():
267      return Type(self._tu,
268                  self._clang_type.get_named_type()).get_related_types(
269                      result, skip_self)
270
271    # Composite types.
272    if self.is_const_array():
273      t = Type(self._tu, self._clang_type.get_array_element_type())
274      return t.get_related_types(result)
275
276    if self._clang_type.kind in (cindex.TypeKind.POINTER,
277                                 cindex.TypeKind.MEMBERPOINTER,
278                                 cindex.TypeKind.LVALUEREFERENCE,
279                                 cindex.TypeKind.RVALUEREFERENCE):
280      return self.get_pointee().get_related_types(result, skip_self)
281
282    # union + struct, class should be filtered out
283    if self.is_struct() or self.is_union():
284      return self._get_related_types_of_record(result, skip_self)
285
286    if self.is_function():
287      return self._get_related_types_of_function(result)
288
289    if self.is_sugared_enum():
290      if not skip_self:
291        result.add(self)
292        self._tu.search_for_macro_name(self._get_declaration())
293      return result
294
295    # Ignore all cindex.TypeKind.UNEXPOSED AST nodes
296    # TODO(b/256934562): Remove the disable once the pytype bug is fixed.
297    return result  # pytype: disable=bad-return-type
298
299  def _get_related_types_of_typedef(self, result):
300    # type: (Set[Type]) -> Set[Type]
301    """Returns all intermediate types related to the typedef."""
302    result.add(self)
303    decl = self._clang_type.get_declaration()
304    self._tu.search_for_macro_name(decl)
305
306    t = Type(self._tu, decl.underlying_typedef_type)
307    if t.is_sugared_ptr():
308      t = t.get_pointee()
309
310    if not t.is_simple_type():
311      skip_child = self.contains_declaration(t)
312      if t.is_sugared_record() and skip_child:
313        # if child declaration is contained in parent, we don't have to emit it
314        self._tu.types_to_skip.add(t)
315      result.update(t.get_related_types(result, skip_child))
316
317    return result
318
319  def _get_related_types_of_record(self, result, skip_self=False):
320    # type: (Set[Type], bool) -> Set[Type]
321    """Returns all types related to the structure."""
322    # skip unnamed structures eg. typedef struct {...} x;
323    # struct {...} will be rendered as part of typedef rendering
324    decl = self._get_declaration()
325    if not decl.is_anonymous() and not skip_self:
326      self._tu.search_for_macro_name(decl)
327      result.add(self)
328
329    for f in self._clang_type.get_fields():
330      self._tu.search_for_macro_name(f)
331      result.update(Type(self._tu, f.type).get_related_types(result))
332
333    return result
334
335  def _get_related_types_of_function(self, result):
336    # type: (Set[Type]) -> Set[Type]
337    """Returns all types related to the function."""
338    for arg in self._clang_type.argument_types():
339      result.update(Type(self._tu, arg).get_related_types(result))
340    related = Type(self._tu,
341                   self._clang_type.get_result()).get_related_types(result)
342    result.update(related)
343
344    return result
345
346  def contains_declaration(self, other):
347    # type: (Type) -> bool
348    """Checks if string representation of a type contains the other type."""
349    self_extent = self._get_declaration().extent
350    other_extent = other._get_declaration().extent  # pylint: disable=protected-access
351
352    if other_extent.start.file is None:
353      return False
354    return (other_extent.start in self_extent and
355            other_extent.end in self_extent)
356
357  def stringify(self):
358    # type: () -> Text
359    """Returns string representation of the Type."""
360    # (szwl): as simple as possible, keeps macros in separate lines not to
361    # break things; this will go through clang format nevertheless
362    tokens = [
363        x for x in self._get_declaration().get_tokens()
364        if x.kind is not cindex.TokenKind.COMMENT
365    ]
366
367    return _stringify_tokens(tokens)
368
369
370class OutputLine(object):
371  """Helper class for Type printing."""
372
373  def __init__(self, tab, tokens):
374    # type: (int, List[cindex.Token]) -> None
375    self.tokens = tokens
376    self.spellings = []
377    self.define = False
378    self.tab = tab
379    self.next_tab = tab
380    list(map(self._process_token, self.tokens))
381
382  def _process_token(self, t):
383    # type: (cindex.Token) -> None
384    """Processes a token, setting up internal states rel. to intendation."""
385    if t.spelling == '#':
386      self.define = True
387    elif t.spelling == '{':
388      self.next_tab += 1
389    elif t.spelling == '}':
390      self.tab -= 1
391      self.next_tab -= 1
392
393    is_bracket = t.spelling == '('
394    is_macro = len(self.spellings) == 1 and self.spellings[0] == '#'
395    if self.spellings and not is_bracket and not is_macro:
396      self.spellings.append(' ')
397    self.spellings.append(t.spelling)
398
399  def __str__(self):
400    # type: () -> Text
401    tabs = ('\t' * self.tab) if not self.define else ''
402    return tabs + ''.join(t for t in self.spellings)
403
404
405class ArgumentType(Type):
406  """Class representing function argument type.
407
408  Object fields are being used by the code template:
409  pos: argument position
410  type: string representation of the type
411  argument: string representation of the type as function argument
412  mapped_type: SAPI equivalent of the type
413  wrapped: wraps type in SAPI object constructor
414  call_argument: type (or it's sapi wrapper) used in function call
415  """
416
417  def __init__(self, function, pos, arg_type, name=None):
418    # type: (Function, int, cindex.Type, Optional[Text]) -> None
419    super(ArgumentType, self).__init__(function.translation_unit(), arg_type)
420    self._function = function
421
422    self.pos = pos
423    self.name = name or 'a{}'.format(pos)
424    self.type = arg_type.spelling
425
426    template = '{}' if self.is_sugared_ptr() else '&{}_'
427    self.call_argument = template.format(self.name)
428
429  def __str__(self):
430    # type: () -> Text
431    """Returns function argument prepared from the type."""
432    if self.is_sugared_ptr():
433      return '::sapi::v::Ptr* {}'.format(self.name)
434
435    return '{} {}'.format(self._clang_type.spelling, self.name)
436
437  @property
438  def wrapped(self):
439    # type: () -> Text
440    return '{} {name}_(({name}))'.format(self.mapped_type, name=self.name)
441
442  @property
443  def mapped_type(self):
444    # type: () -> Text
445    """Maps the type to its SAPI equivalent."""
446    if self.is_sugared_ptr():
447      # TODO(szwl): const ptrs do not play well with SAPI C++ API...
448      spelling = self._clang_type.spelling.replace('const', '')
449      return '::sapi::v::Reg<{}>'.format(spelling)
450
451    type_ = self._clang_type
452
453    if type_.kind == cindex.TypeKind.TYPEDEF:
454      type_ = self._clang_type.get_canonical()
455    if type_.kind == cindex.TypeKind.ELABORATED:
456      type_ = type_.get_canonical()
457    if type_.kind == cindex.TypeKind.ENUM:
458      return '::sapi::v::IntBase<{}>'.format(self._clang_type.spelling)
459    if type_.kind in [
460        cindex.TypeKind.CONSTANTARRAY, cindex.TypeKind.INCOMPLETEARRAY
461    ]:
462      return '::sapi::v::Reg<{}>'.format(self._clang_type.spelling)
463
464    if type_.kind == cindex.TypeKind.LVALUEREFERENCE:
465      return 'LVALUEREFERENCE::NOT_SUPPORTED'
466
467    if type_.kind == cindex.TypeKind.RVALUEREFERENCE:
468      return 'RVALUEREFERENCE::NOT_SUPPORTED'
469
470    if type_.kind in [cindex.TypeKind.RECORD, cindex.TypeKind.ELABORATED]:
471      raise ValueError('Elaborate type (eg. struct) in mapped_type is not '
472                       'supported: function {}, arg {}, type {}, location {}'
473                       ''.format(self._function.name, self.pos,
474                                 self._clang_type.spelling,
475                                 self._function.cursor.location))
476
477    if type_.kind not in TYPE_MAPPING:
478      raise KeyError('Key {} does not exist in TYPE_MAPPING.'
479                     ' function {}, arg {}, type {}, location {}'
480                     ''.format(type_.kind, self._function.name, self.pos,
481                               self._clang_type.spelling,
482                               self._function.cursor.location))
483
484    return TYPE_MAPPING[type_.kind]
485
486
487class ReturnType(ArgumentType):
488  """Class representing function return type.
489
490     Attributes:
491       return_type: absl::StatusOr<T> where T is original return type, or
492                    absl::Status for functions returning void
493  """
494
495  def __init__(self, function, arg_type):
496    # type: (Function, cindex.Type) -> None
497    super(ReturnType, self).__init__(function, 0, arg_type, None)
498
499  def __str__(self):
500    # type: () -> Text
501    """Returns function return type prepared from the type."""
502    # TODO(szwl): const ptrs do not play well with SAPI C++ API...
503    spelling = self._clang_type.spelling.replace('const', '')
504    return_type = 'absl::StatusOr<{}>'.format(spelling)
505    return_type = 'absl::Status' if self.is_void() else return_type
506    return return_type
507
508
509class Function(object):
510  """Class representing SAPI-wrapped function used by the template.
511
512  Wraps Clang cursor object of kind FUNCTION_DECL and provides helpers to
513  aid code generation.
514  """
515
516  def __init__(self, tu, cursor):
517    # type: (_TranslationUnit, cindex.Cursor) -> None
518    self._tu = tu
519    self.cursor = cursor  # type: cindex.Index
520    self.name = cursor.spelling  # type: Text
521    self.result = ReturnType(self, cursor.result_type)
522    self.original_definition = '{} {}'.format(
523        cursor.result_type.spelling, self.cursor.displayname)  # type: Text
524
525    types = self.cursor.get_arguments()
526    self.argument_types = [
527        ArgumentType(self, i, t.type, t.spelling) for i, t in enumerate(types)
528    ]
529
530  def translation_unit(self):
531    # type: () -> _TranslationUnit
532    return self._tu
533
534  def arguments(self):
535    # type: () -> List[ArgumentType]
536    return self.argument_types
537
538  def call_arguments(self):
539    # type: () -> List[Text]
540    return [a.call_argument for a in self.argument_types]
541
542  def get_absolute_path(self):
543    # type: () -> Text
544    return self.cursor.location.file.name
545
546  def get_include_path(self, prefix):
547    # type: (Optional[Text]) -> Text
548    """Creates a proper include path."""
549    # TODO(szwl): sanity checks
550    # TODO(szwl): prefix 'utils/' and the path is '.../fileutils/...' case
551    if prefix and not prefix.endswith('/'):
552      prefix += '/'
553
554    if not prefix:
555      return self.get_absolute_path()
556    elif prefix in self.get_absolute_path():
557      return prefix + self.get_absolute_path().split(prefix)[-1]
558    return prefix + self.get_absolute_path().split('/')[-1]
559
560  def get_related_types(self, processed=None):
561    # type: (Optional[Set[Type]]) -> Set[Type]
562    result = self.result.get_related_types(processed)
563    for a in self.argument_types:
564      result.update(a.get_related_types(processed))
565
566    return result
567
568  def is_mangled(self):
569    # type: () -> bool
570    return self.cursor.mangled_name != self.cursor.spelling
571
572  def __hash__(self):
573    # type: () -> int
574    return hash(self.cursor.get_usr())
575
576  def __eq__(self, other):
577    # type: (Function) -> bool
578    return self.cursor.mangled_name == other.cursor.mangled_name
579
580
581class _TranslationUnit(object):
582  """Class wrapping clang's _TranslationUnit. Provides extra utilities."""
583
584  def __init__(self, path, tu, limit_scan_depth=False, func_names=None):
585    # type: (Text, cindex.TranslationUnit, bool, Optional[List[Text]]) -> None
586    """Initializes the translation unit.
587
588    Args:
589      path: path to source of the tranlation unit
590      tu: cindex tranlation unit
591      limit_scan_depth: whether scan should be limited to single file
592      func_names: list of function names to take into consideration, empty means
593        all functions.
594    """
595    self.path = path
596    self.limit_scan_depth = limit_scan_depth
597    self._tu = tu
598    self._processed = False
599    self.forward_decls = dict()
600    self.functions = set()
601    self.order = dict()
602    self.defines = {}
603    self.required_defines = set()
604    self.types_to_skip = set()
605    self.func_names = func_names or []
606
607  def _process(self):
608    # type: () -> None
609    """Walks the cursor tree and caches some for future use."""
610    if not self._processed:
611      # self.includes[self._tu.spelling] = (0, self._tu.cursor)
612      self._processed = True
613      # TODO(szwl): duplicates?
614      # TODO(szwl): for d in translation_unit.diagnostics:, handle that
615
616      for i, cursor in enumerate(self._walk_preorder()):
617        # Workaround for issue#32
618        # ignore all the cursors with kinds not implemented in python bindings
619        try:
620          cursor.kind
621        except ValueError:
622          continue
623        # naive way to order types: they should be ordered when walking the tree
624        if cursor.kind.is_declaration():
625          self.order[cursor.hash] = i
626
627        if (cursor.kind == cindex.CursorKind.MACRO_DEFINITION and
628            cursor.location.file):
629          self.order[cursor.hash] = i
630          self.defines[cursor.spelling] = cursor
631
632        # most likely a forward decl of struct
633        if (cursor.kind == cindex.CursorKind.STRUCT_DECL and
634            not cursor.is_definition()):
635          self.forward_decls[Type(self, cursor.type)] = cursor
636        if (cursor.kind == cindex.CursorKind.FUNCTION_DECL and
637            cursor.linkage != cindex.LinkageKind.INTERNAL):
638          # Skip non-interesting functions
639          if self.func_names and cursor.spelling not in self.func_names:
640            continue
641          if self.limit_scan_depth:
642            if (cursor.location and cursor.location.file.name == self.path):
643              self.functions.add(Function(self, cursor))
644          else:
645            self.functions.add(Function(self, cursor))
646
647  def get_functions(self):
648    # type: () -> Set[Function]
649    self._process()
650    return self.functions
651
652  def _walk_preorder(self):
653    # type: () -> Gen
654    for c in self._tu.cursor.walk_preorder():
655      yield c
656
657  def search_for_macro_name(self, cursor):
658    # type: (cindex.Cursor) -> None
659    """Searches for possible macro usage in constant array types."""
660    tokens = list(t.spelling for t in cursor.get_tokens())
661    try:
662      for token in tokens:
663        if token in self.defines and token not in self.required_defines:
664          self.required_defines.add(token)
665          self.search_for_macro_name(self.defines[token])
666    except ValueError:
667      return
668
669
670class Analyzer(object):
671  """Class responsible for analysis."""
672
673  # pylint: disable=line-too-long
674  @staticmethod
675  def process_files(
676      input_paths, compile_flags, limit_scan_depth=False, func_names=None
677  ):
678    # type: (Text, List[Text], bool, Optional[List[Text]]) -> List[_TranslationUnit]
679    """Processes files with libclang and returns TranslationUnit objects."""
680
681    tus = []
682    for path in input_paths:
683      tu = Analyzer._analyze_file_for_tu(
684          path,
685          compile_flags=compile_flags,
686          limit_scan_depth=limit_scan_depth,
687          func_names=func_names,
688      )
689      tus.append(tu)
690    return tus
691
692  # pylint: disable=line-too-long
693  @staticmethod
694  def _analyze_file_for_tu(
695      path,
696      compile_flags=None,
697      test_file_existence=True,
698      unsaved_files=None,
699      limit_scan_depth=False,
700      func_names=None,
701  ):
702    # type: (Text, Optional[List[Text]], bool, Optional[Tuple[Text, Union[Text, IO[Text]]]], bool, Optional[List[Text]]) -> _TranslationUnit
703    """Returns Analysis object for given path."""
704    compile_flags = compile_flags or []
705    if test_file_existence and not os.path.isfile(path):
706      raise IOError('Path {} does not exist.'.format(path))
707
708    index = cindex.Index.create()  # type: cindex.Index
709    # TODO(szwl): hack until I figure out how python swig does that.
710    # Headers will be parsed as C++. C libs usually have
711    # '#ifdef __cplusplus extern "C"' for compatibility with c++
712    lang = '-xc++' if not path.endswith('.c') else '-xc'
713    args = [lang]
714    args += compile_flags
715    args.append('-I.')
716    return _TranslationUnit(
717        path,
718        index.parse(
719            path, args=args, unsaved_files=unsaved_files, options=_PARSE_OPTIONS
720        ),
721        limit_scan_depth=limit_scan_depth,
722        func_names=func_names,
723    )
724
725
726class Generator(object):
727  """Class responsible for code generation."""
728
729  AUTO_GENERATED = ('// AUTO-GENERATED by the Sandboxed API generator.\n'
730                    '// Edits will be discarded when regenerating this file.\n')
731
732  GUARD_START = ('#ifndef {0}\n' '#define {0}')
733  GUARD_END = '#endif  // {}'
734  EMBED_INCLUDE = '#include "{}"'
735  EMBED_CLASS = '''
736class {0}Sandbox : public ::sapi::Sandbox {{
737  public:
738    {0}Sandbox()
739      : ::sapi::Sandbox([]() {{
740          static auto* fork_client_context =
741              new ::sapi::ForkClientContext({1}_embed_create());
742          return fork_client_context;
743        }}()) {{}}
744}};'''
745
746  def __init__(self, translation_units):
747    # type: (List[cindex.TranslationUnit]) -> None
748    """Initializes the generator.
749
750    Args:
751      translation_units: list of translation_units for analyzed files,
752        facultative. If not given, then one is computed for each element of
753        input_paths
754    """
755    self.translation_units = translation_units
756    self.functions = None
757
758  def generate(
759      self,
760      name,
761      namespace=None,
762      output_file=None,
763      embed_dir=None,
764      embed_name=None,
765  ):
766    # pylint: disable=line-too-long
767    # type: (Text, Optional[Text], Optional[Text], Optional[Text], Optional[Text]) -> Text
768    """Generates structures, functions and typedefs.
769
770    Args:
771      name: name of the class that will contain generated interface
772      namespace: namespace of the interface
773      output_file: path to the output file, used to generate header guards;
774        defaults to None that does not generate the guard #include directives;
775        defaults to None that causes to emit the whole file path
776      embed_dir: path to directory with embed includes
777      embed_name: name of the embed object
778
779    Returns:
780      generated interface as a string
781    """
782    related_types = self._get_related_types()
783    forward_decls = self._get_forward_decls(related_types)
784    functions = self._get_functions()
785    related_types = [(t.stringify() + ';') for t in related_types]
786    defines = self._get_defines()
787
788    api = {
789        'name': name,
790        'functions': functions,
791        'related_types': defines + forward_decls + related_types,
792        'namespaces': namespace.split('::') if namespace else [],
793        'embed_dir': embed_dir,
794        'embed_name': embed_name,
795        'output_file': output_file
796    }
797    return self.format_template(**api)
798
799  def _get_functions(self):
800    # type: () -> List[Function]
801    """Gets Function objects that will be used to generate interface."""
802    if self.functions is not None:
803      return self.functions
804    self.functions = []
805    # TODO(szwl): for d in translation_unit.diagnostics:, handle that
806    for translation_unit in self.translation_units:
807      self.functions += translation_unit.get_functions()
808    # allow only nonmangled functions - C++ overloads are not handled in
809    # code generation
810    self.functions = [f for f in self.functions if not f.is_mangled()]
811
812    # remove duplicates
813    self.functions = list(set(self.functions))
814    self.functions.sort(key=lambda x: x.name)
815    return self.functions
816
817  def _get_related_types(self):
818    # type: () -> List[Type]
819    """Gets type definitions related to chosen functions.
820
821    Types related to one function will land in the same translation unit,
822    we gather the types, sort it and put as a sublist in types list.
823    This is necessary as we can't compare types from two different translation
824    units.
825
826    Returns:
827      list of types in correct (ready to render) order
828    """
829    processed = set()
830    types = []
831    types_to_skip = set()
832
833    for f in self._get_functions():
834      fn_related_types = f.get_related_types()
835      types += sorted(r for r in fn_related_types if r not in processed)
836      processed.update(fn_related_types)
837      types_to_skip.update(f.translation_unit().types_to_skip)
838
839    return [t for t in types if t not in types_to_skip]
840
841  def _get_defines(self):
842    # type: () -> List[Text]
843    """Gets #define directives that appeared during TranslationUnit processing.
844
845    Returns:
846      list of #define string representations
847    """
848
849    def make_sort_condition(translation_unit):
850      return lambda cursor: translation_unit.order[cursor.hash]
851
852    result = []
853    for tu in self.translation_units:
854      tmp_result = []
855      sort_condition = make_sort_condition(tu)
856      for name in tu.required_defines:
857        if name in tu.defines:
858          define = tu.defines[name]
859          tmp_result.append(define)
860      for define in sorted(tmp_result, key=sort_condition):
861        result.append('#define ' +
862                      _stringify_tokens(define.get_tokens(), separator=' \\\n'))
863    return result
864
865  def _get_forward_decls(self, types):
866    # type: (List[Type]) -> List[Text]
867    """Gets forward declarations of related types, if present."""
868    forward_decls = dict()
869    result = []
870    done = set()
871    for tu in self.translation_units:
872      forward_decls.update(tu.forward_decls)
873
874      for t in types:
875        if t in forward_decls and t not in done:
876          result.append(_stringify_tokens(forward_decls[t].get_tokens()) + ';')
877          done.add(t)
878
879    return result
880
881  def _format_function(self, f):
882    # type: (Function) -> Text
883    """Renders one function of the Api.
884
885    Args:
886      f: function object with information necessary to emit full function body
887
888    Returns:
889      filled function template
890    """
891    result = []
892    result.append('  // {}'.format(f.original_definition))
893
894    arguments = ', '.join(str(a) for a in f.arguments())
895    result.append('  {} {}({}) {{'.format(f.result, f.name, arguments))
896    result.append('    {} ret;'.format(f.result.mapped_type))
897
898    argument_types = []
899    for a in f.argument_types:
900      if not a.is_sugared_ptr():
901        argument_types.append(a.wrapped + ';')
902    if argument_types:
903      for arg in argument_types:
904        result.append('    {}'.format(arg))
905
906    call_arguments = f.call_arguments()
907    if call_arguments:  # fake empty space to add ',' before first argument
908      call_arguments.insert(0, '')
909    result.append('')
910    # For OSS, the macro below will be replaced.
911    result.append('    SAPI_RETURN_IF_ERROR(sandbox_->Call("{}", &ret{}));'
912                  ''.format(f.name, ', '.join(call_arguments)))
913
914    return_status = 'return absl::OkStatus();'
915    if f.result and not f.result.is_void():
916      if f.result and f.result.is_sugared_enum():
917        return_status = ('return static_cast<{}>'
918                         '(ret.GetValue());').format(f.result.type)
919      else:
920        return_status = 'return ret.GetValue();'
921    result.append('    {}'.format(return_status))
922    result.append('  }')
923
924    return '\n'.join(result)
925
926  def format_template(self, name, functions, related_types, namespaces,
927                      embed_dir, embed_name, output_file):
928    # pylint: disable=line-too-long
929    # type: (Text, List[Function], List[Text], List[Text], Text, Text, Text) -> Text
930    # pylint: enable=line-too-long
931    """Formats arguments into proper interface header file.
932
933    Args:
934      name: name of the Api - 'Test' will yield TestApi object
935      functions: list of functions to generate
936      related_types: types used in the above functions
937      namespaces: list of namespaces to wrap the Api class with
938      embed_dir: directory where the embedded library lives
939      embed_name: name of embedded library
940      output_file: interface output path - used in header guard generation
941
942    Returns:
943      generated header file text
944    """
945    result = [Generator.AUTO_GENERATED]
946
947    header_guard = get_header_guard(output_file) if output_file else ''
948    if header_guard:
949      result.append(Generator.GUARD_START.format(header_guard))
950
951    # Copybara transform results in the paths below.
952    result.append('#include "absl/status/status.h"')
953    result.append('#include "absl/status/statusor.h"')
954    result.append('#include "sandboxed_api/sandbox.h"')
955    result.append('#include "sandboxed_api/util/status_macros.h"')
956    result.append('#include "sandboxed_api/vars.h"')
957
958    if embed_name:
959      embed_dir = embed_dir or ''
960      result.append(
961          Generator.EMBED_INCLUDE.format(
962              os.path.join(embed_dir, embed_name) + '_embed.h'))
963
964    if namespaces:
965      result.append('')
966      for n in namespaces:
967        result.append('namespace {} {{'.format(n))
968
969    if related_types:
970      result.append('')
971      for t in related_types:
972        result.append(t)
973
974    result.append('')
975
976    if embed_name:
977      result.append(
978          Generator.EMBED_CLASS.format(name, embed_name.replace('-', '_')))
979
980    result.append('class {}Api {{'.format(name))
981    result.append(' public:')
982    result.append('  explicit {}Api(::sapi::Sandbox* sandbox)'
983                  ' : sandbox_(sandbox) {{}}'.format(name))
984    result.append('  // Deprecated')
985    result.append('  ::sapi::Sandbox* GetSandbox() const { return sandbox(); }')
986    result.append('  ::sapi::Sandbox* sandbox() const { return sandbox_; }')
987
988    for f in functions:
989      result.append('')
990      result.append(self._format_function(f))
991
992    result.append('')
993    result.append(' private:')
994    result.append('  ::sapi::Sandbox* sandbox_;')
995    result.append('};')
996    result.append('')
997
998    if namespaces:
999      for n in reversed(namespaces):
1000        result.append('}}  // namespace {}'.format(n))
1001
1002    if header_guard:
1003      result.append(Generator.GUARD_END.format(header_guard))
1004
1005    result.append('')
1006
1007    return '\n'.join(result)
1008