• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2#
3# Copyright 2024 The Chromium Authors
4# Use of this source code is governed by a BSD-style license that can be
5# found in the LICENSE file.
6
7import argparse
8import collections
9import copy
10import os
11import pathlib
12import sys
13import typing
14import re
15import dataclasses
16
17
def _GetDirAbove(dirname: str):
  """Walks up from this script's absolute path looking for a path component
  named |dirname|.

  Args:
      dirname: the directory name to look for.

  Returns:
      the directory that contains the |dirname| component, or None when the
      filesystem root is reached without a match.
  """
  head, tail = os.path.split(os.path.abspath(__file__))
  # An empty tail means os.path.split() has reached the filesystem root.
  while tail:
    if tail == dirname:
      return head
    head, tail = os.path.split(head)
  return None
28
29
# Root of the checkout: the directory that contains `testing/`.
SOURCE_DIR = _GetDirAbove('testing')

# third_party/ and Domato's sources are inserted right after sys.path[0] (the
# second insert ends up first), while build/ is appended at the end. The
# action_helpers / jinja2 / grammar imports that follow presumably resolve
# from these directories.
for _extra_dir in ('third_party', 'third_party/domato/src'):
  sys.path.insert(1, os.path.join(SOURCE_DIR, _extra_dir))
del _extra_dir
sys.path.append(os.path.join(SOURCE_DIR, 'build'))
35
36import action_helpers
37import jinja2
38import grammar
39
40# TODO(crbug.com/361369290): Remove this disable once DomatoLPM development is
41# finished and upstream changes can be made to expose the relevant protected
42# fields.
43# pylint: disable=protected-access
44
def to_snake_case(name):
  """Converts a CamelCase (or mixedCase) name to snake_case.

  A run of capitals is kept together as one word, so 'HTMLElement' becomes
  'html_element' and 'fooBar' becomes 'foo_bar'.

  Args:
      name: the name to convert.

  Returns:
      the lower-cased, underscore-separated name.
  """
  # Split an acronym from the capitalized word that follows it.
  name = re.sub(r'([A-Z]{2,})([A-Z][a-z])', r'\1_\2', name)
  # Split a lowercase letter or digit from a following capital. re.sub
  # replaces every occurrence by default; passing `count` positionally is
  # deprecated since Python 3.13.
  return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', name).lower()
48
49
# Domato integer tag -> C++ integer type. Used to emit the implicit lower
# bound `std::numeric_limits<T>::min()` when a tag only specifies `max`, so
# signedness must match the Domato type.
DOMATO_INT_TYPE_TO_CPP_INT_TYPE = {
    'int': 'int',
    'int32': 'int32_t',
    'uint32': 'uint32_t',
    'int8': 'int8_t',
    'uint8': 'uint8_t',
    'int16': 'int16_t',
    'uint16': 'uint16_t',
    'int64': 'int64_t',
    'uint64': 'uint64_t',
}

# Domato built-in tag -> protobuf field type. Protobuf has no 8- or 16-bit
# scalar types, so narrow integers are carried in 32-bit fields and narrowed
# by the matching handle_int_conversion handler below.
DOMATO_TO_PROTO_BUILT_IN = {
    'int': 'int32',
    'int32': 'int32',
    'uint32': 'uint32',
    'int8': 'int32',
    'uint8': 'uint32',
    'int16': 'int32',
    'uint16': 'uint32',
    'int64': 'int64',
    'uint64': 'uint64',
    'float': 'float',
    'double': 'double',
    'char': 'int32',
    'string': 'string',
    'htmlsafestring': 'string',
    'hex': 'int32',
    'lines': 'repeated lines',
}

# Domato built-in tag -> C++ handler. For integers, the first template
# argument is the C++ type of the proto field (see DOMATO_TO_PROTO_BUILT_IN)
# and the second one the Domato type it is converted to.
DOMATO_TO_CPP_HANDLERS = {
    'int': 'handle_int_conversion<int32_t, int>',
    'int32': 'handle_int_conversion<int32_t, int32_t>',
    'uint32': 'handle_int_conversion<uint32_t, uint32_t>',
    'int8': 'handle_int_conversion<int32_t, int8_t>',
    'uint8': 'handle_int_conversion<uint32_t, uint8_t>',
    'int16': 'handle_int_conversion<int32_t, int16_t>',
    'uint16': 'handle_int_conversion<uint32_t, uint16_t>',
    'int64': 'handle_int_conversion<int64_t, int64_t>',
    'uint64': 'handle_int_conversion<uint64_t, uint64_t>',
    'float': 'handle_float',
    'double': 'handle_double',
    'char': 'handle_char',
    'string': 'handle_string',
    'htmlsafestring': 'handle_string',
    'hex': 'handle_hex',
}

# Characters that must be escaped inside a double-quoted C++ string literal.
_C_STR_TRANS = str.maketrans({
    '\n': '\\n',
    '\r': '\\r',
    '\t': '\\t',
    '\"': '\\\"',
    '\\': '\\\\'
})

BASE_PROTO_NS = 'domatolpm.generated'
108
109
def to_cpp_ns(proto_ns: str) -> str:
  """Translates a dotted protobuf package name into a C++ namespace path,
  e.g. 'domatolpm.generated' -> 'domatolpm::generated'."""
  return '::'.join(proto_ns.split('.'))
112
113
# Prepended to a proto type name to form the name of the C++ function that
# handles it (e.g. `handle_fuzzcase`).
CPP_HANDLER_PREFIX: str = 'handle_'
115
116
def to_proto_field_name(name: str) -> str:
  """Derives a proto field name from a creator or rule name.

  Protobuf field names are conventionally snake case, so dashes become
  underscores and the result is snake-cased. Names in a small reserved list
  get a `_proto` suffix.

  Args:
      name: the name of the creator or the rule.

  Returns:
      the proto field name to use.
  """
  field_name = to_snake_case(name.replace('-', '_'))
  reserved = ('short', 'class', 'bool', 'boolean', 'long', 'void')
  if field_name in reserved:
    return f'{field_name}_proto'
  return field_name
131
132
def to_proto_type(creator_name: str) -> str:
  """Derives the proto type name of a creator.

  This is deliberately minimal to avoid naming conflicts: dashes become
  underscores, names in a small reserved list get a `_proto` suffix, and
  case is left untouched.

  Args:
      creator_name: the name of the creator.

  Returns:
      the name of the proto type.
  """
  type_name = creator_name.replace('-', '_')
  reserved = ('short', 'class', 'bool', 'boolean', 'long', 'void')
  return type_name + '_proto' if type_name in reserved else type_name
147
148
def c_escape(v: str) -> str:
  """Escapes newlines, carriage returns, tabs, double quotes and backslashes
  so |v| can be embedded in a double-quoted C++ string literal."""
  return str.translate(v, _C_STR_TRANS)
151
152
@dataclasses.dataclass
class ProtoType:
  """A proto type referenced by name: a scalar (possibly with a label, such
  as 'optional int32') or a generated message."""
  name: str

  def is_one_of(self) -> bool:
    """Whether this type is a oneof message; plain types never are."""
    return False


@dataclasses.dataclass
class ProtoField:
  """A field of a proto message."""
  # Type of the field.
  type: ProtoType
  # Field name.
  name: str
  # Proto field number.
  proto_id: int


@dataclasses.dataclass
class ProtoMessage(ProtoType):
  """A proto message made of an ordered list of fields."""
  fields: typing.List[ProtoField]


@dataclasses.dataclass
class OneOfProtoMessage(ProtoMessage):
  """A proto message whose fields are the cases of a oneof."""
  # Name of the oneof holding |fields|.
  oneofname: str

  def is_one_of(self) -> bool:
    """Always True for this message kind."""
    return True
183
184
class CppExpression:
  """Base class of the C++ expressions making up generated handlers."""

  def repr(self) -> str:
    """Returns the C++ source text of this expression.

    Raises:
        NotImplementedError: always; subclasses override this.
    """
    raise NotImplementedError('Not implemented.')


@dataclasses.dataclass
class CppTxtExpression(CppExpression):
  """Represents a raw text expression, emitted verbatim."""
  content: str

  def repr(self) -> str:
    return self.content


@dataclasses.dataclass
class CppCallExpr(CppExpression):
  """Represents a call expression `<ns><fct_name>(<args>)`."""
  fct_name: str
  args: typing.List[CppExpression]
  # Qualifier prepended verbatim to the function name (e.g. 'ns::').
  ns: str = ''

  def repr(self) -> str:
    arg_s = ', '.join(a.repr() for a in self.args)
    return f'{self.ns}{self.fct_name}({arg_s})'


class CppHandlerCallExpr(CppCallExpr):
  """A call `handler(ctx, arg.<field_name>(), <extra_args>...)`.

  handler, field_name and extra_args are kept as attributes so that later
  passes can rebuild the call for a renamed field.
  """

  def __init__(self,
               handler: str,
               field_name: str,
               extra_args: typing.Optional[typing.List[CppExpression]] = None):
    args = [CppTxtExpression('ctx'), CppTxtExpression(f'arg.{field_name}()')]
    if extra_args:
      args += extra_args
    super().__init__(fct_name=handler, args=args)
    self.handler = handler
    self.field_name = field_name
    self.extra_args = extra_args


@dataclasses.dataclass
class CppStringExpr(CppExpression):
  """Represents a C++ string literal; |content| is escaped when rendered."""
  content: str

  def repr(self) -> str:
    return f'\"{c_escape(self.content)}\"'
235
236
@dataclasses.dataclass
class CppFunctionHandler:
  """Represents a generated C++ function.

  Attributes:
      name: the C++ function name.
      exprs: the expressions forming the function body.
  """
  name: str
  exprs: typing.List[CppExpression]

  def is_oneof_handler(self) -> bool:
    return False

  def is_string_table_handler(self) -> bool:
    return False

  def is_message_handler(self) -> bool:
    return False

  def creates_new(self) -> bool:
    """Whether the handler creates a new variable.

    Only message handlers built with a creator override this. Defining it
    here lets every handler answer instead of raising AttributeError.
    """
    return False


class CppStringTableHandler(CppFunctionHandler):
  """Represents a C++ function that implements a string table and returns one
  of the represented strings. The choice is presumably driven by the message
  field named |var_name|.
  """

  def __init__(self, name: str, var_name: str,
               strings: typing.List[CppStringExpr]):
    super().__init__(name=f'{CPP_HANDLER_PREFIX}{name}', exprs=[])
    # C++ parameter declaration for the handled proto message.
    self.proto_type = f'{name}& arg'
    self.strings = strings
    self.var_name = var_name

  def is_string_table_handler(self) -> bool:
    return True


class CppProtoMessageFunctionHandler(CppFunctionHandler):
  """Represents a C++ function that handles a ProtoMessage."""

  def __init__(self,
               name: str,
               exprs: typing.List[CppExpression],
               creator: typing.Optional[typing.Dict[str, str]] = None):
    super().__init__(name=f'{CPP_HANDLER_PREFIX}{name}', exprs=exprs)
    # C++ parameter declaration for the handled proto message.
    self.proto_type = f'{name}& arg'
    # For code rules that create a variable: {'var_type': ..., 'var_prefix':
    # ...}; None otherwise.
    self.creator = creator

  def creates_new(self) -> bool:
    return self.creator is not None

  def is_message_handler(self) -> bool:
    return True


class CppOneOfMessageFunctionHandler(CppFunctionHandler):
  """Represents a C++ function that handles a OneOfProtoMessage, with one
  case per oneof field, switching on |switch_name|."""

  def __init__(self, name: str, switch_name: str,
               cases: typing.Dict[str, typing.List[CppExpression]]):
    super().__init__(name=f'{CPP_HANDLER_PREFIX}{name}', exprs=[])
    # C++ parameter declaration for the handled proto message.
    self.proto_type = f'{name}& arg'
    self.switch_name = switch_name
    # Oneof field name -> expressions of that case, in insertion order.
    self.cases = cases

  def all_except_last(self) -> typing.Dict[str, typing.List[CppExpression]]:
    """Returns every case but the last one, preserving order."""
    return dict(list(self.cases.items())[:-1])

  def last(self) -> typing.List[CppExpression]:
    """Returns the expressions of the last case.

    Raises:
        IndexError: if there are no cases.
    """
    return list(self.cases.values())[-1]

  def is_oneof_handler(self) -> bool:
    return True
310
311
class DomatoBuilder:
  """DomatoBuilder takes a Domato grammar and models it as a protobuf
  representation together with the corresponding C++ parsing code.
  """

  @dataclasses.dataclass
  class Entry:
    """A proto message paired with the C++ function that handles it."""
    msg: ProtoMessage
    func: CppFunctionHandler

  def __init__(self, g: grammar.Grammar):
    """Creates a builder for |g| and selects the fuzz case root.

    Args:
        g: the Domato grammar to model.
    """
    # Message name -> proto message and C++ handler.
    self.handlers: typing.Dict[str, DomatoBuilder.Entry] = {}
    # Type name -> names of the messages having a field of that type. A
    # message is listed once per such field.
    self.backrefs: typing.Dict[str, typing.List[str]] = {}
    self.grammar = g
    grammar_root = self.grammar._root
    if grammar_root and grammar_root != 'root':
      self.root = grammar_root
    else:
      self.root = 'lines'
    if grammar_root == 'root':
      rules = self.grammar._creators[grammar_root]
      # multiple roots doesn't make sense, so we only consider the last defined
      # one.
      rule = rules[-1]
      for part in rule['parts']:
        if (part['type'] == 'tag' and part['tagname'] == 'lines'
            and 'count' in part):
          # Normalized through int() so that it matches the `lines_N` message
          # registered by _lines_handler().
          self.root = f'lines_{int(part["count"])}'
          break
    # Tag name -> parser for Domato built-in types. Each parser returns the
    # proto type name of the field and the C++ call handling it.
    self._built_in_types_parser = {
        'int': self._int_handler,
        'int32': self._int_handler,
        'uint32': self._int_handler,
        'int8': self._int_handler,
        'uint8': self._int_handler,
        'int16': self._int_handler,
        'uint16': self._int_handler,
        'int64': self._int_handler,
        'uint64': self._int_handler,
        'float': self._default_handler,
        'double': self._default_handler,
        'char': self._default_handler,
        'string': self._default_handler,
        'htmlsafestring': self._default_handler,
        'hex': self._default_handler,
        'lines': self._lines_handler,
    }

  def parse_grammar(self):
    """Models every creator of the grammar.

    Each rule becomes a message (see _parse_rule()). Each creator becomes a
    oneof message with one case per rule.
    """
    for creator, rules in self.grammar._creators.items():
      field_name = to_proto_field_name(creator)
      type_name = to_proto_type(creator)
      messages = self._parse_rule(creator, rules)
      proto_fields: typing.List[ProtoField] = []
      for proto_id, msg in enumerate(messages, start=1):
        proto_fields.append(
            ProtoField(type=ProtoType(name=msg.name),
                       name=f'{field_name}_{proto_id}',
                       proto_id=proto_id))
      msg = OneOfProtoMessage(name=type_name,
                              oneofname='oneoffield',
                              fields=proto_fields)
      cases = {
          f.name: [
              CppHandlerCallExpr(handler=f'{CPP_HANDLER_PREFIX}{f.type.name}',
                                 field_name=f.name)
          ]
          for f in proto_fields
      }
      func = CppOneOfMessageFunctionHandler(name=type_name,
                                            switch_name='oneoffield',
                                            cases=cases)
      self._add(msg, func)

  def all_proto_messages(self) -> typing.List[ProtoMessage]:
    """Returns every registered proto message."""
    return [v.msg for v in self.handlers.values()]

  def all_cpp_functions(self) -> typing.List[CppFunctionHandler]:
    """Returns every registered C++ function."""
    return [v.func for v in self.handlers.values()]

  def get_line_prefix(self) -> str:
    """Returns the line guard text before `<line>`, or '' without a guard."""
    if not self.grammar._line_guard:
      return ''
    return self.grammar._line_guard.split('<line>')[0]

  def get_line_suffix(self) -> str:
    """Returns the line guard text after `<line>`, or '' without a guard."""
    if not self.grammar._line_guard:
      return ''
    return self.grammar._line_guard.split('<line>')[1]

  def should_generate_repeated_lines(self):
    """Whether the root is the repeated `lines` message."""
    return self.root == 'lines'

  def should_generate_one_line_handler(self):
    """Whether the root is line based (`lines` or `lines_N`), i.e. whether a
    code grammar is being handled."""
    return self.root.startswith('lines')

  def maybe_add_lines_handler(self, number: int) -> bool:
    """Registers a `lines_<number>` message made of |number| `line` fields.

    Args:
        number: the number of lines.

    Returns:
        True if the message was added, False if it already existed.
    """
    name = f'lines_{number}'
    if name in self.handlers:
      return False
    fields = []
    exprs = []
    for i in range(1, number + 1):
      fields.append(ProtoField(ProtoType('line'), f'line_{i}', i))
      exprs.append(CppHandlerCallExpr('handle_one_line', f'line_{i}'))
    msg = ProtoMessage(name, fields=fields)
    handler = CppProtoMessageFunctionHandler(name, exprs=exprs)
    # NOTE: registered directly rather than through _add(), so no backrefs are
    # recorded for the `line` fields.
    self.handlers[name] = DomatoBuilder.Entry(msg, handler)
    return True

  def get_roots(self) -> typing.Tuple[ProtoMessage, CppFunctionHandler]:
    """Returns the `fuzzcase` message, whose single `root` field has the root
    type, along with its C++ handler."""
    root = self.root
    root_handler = f'{CPP_HANDLER_PREFIX}{root}'
    fuzz_case = ProtoMessage(
        name='fuzzcase',
        fields=[ProtoField(type=ProtoType(name=root), name='root', proto_id=1)])
    fuzz_fct = CppProtoMessageFunctionHandler(
        name='fuzzcase',
        exprs=[CppHandlerCallExpr(handler=root_handler, field_name='root')])
    return fuzz_case, fuzz_fct

  def get_protos(
      self
  ) -> typing.Tuple[typing.List[ProtoMessage], typing.List[ProtoMessage]]:
    """Splits the messages into (roots, non_roots).

    For line-based grammars, the roots are the messages whose name starts
    with 'line' plus the fuzz case. Otherwise the fuzz case is the only root
    and every registered message is a non-root.
    """
    if self.should_generate_one_line_handler():
      # We're handling a code grammar.
      roots = [v.msg for k, v in self.handlers.items() if k.startswith('line')]
      roots.append(self.get_roots()[0])
      non_roots = [
          v.msg for k, v in self.handlers.items() if not k.startswith('line')
      ]
      return roots, non_roots
    return [self.get_roots()[0]], self.all_proto_messages()

  def simplify(self):
    """Simplifies the proto and functions.

    The merge passes run until none of them makes a change. Oneof cases are
    then reordered and all fields renamed/renumbered.
    """
    should_continue = True
    while should_continue:
      should_continue = False
      should_continue |= self._merge_unary_oneofs()
      should_continue |= self._merge_strings()
      should_continue |= self._merge_multistrings_oneofs()
      should_continue |= self._remove_unlinked_nodes()
      should_continue |= self._merge_proto_messages()
      should_continue |= self._merge_oneofs()
    self._oneofs_reorderer()
    self._oneof_message_renamer()
    self._message_renamer()

  def _add(self, message: ProtoMessage,
           handler: CppFunctionHandler):
    """Registers |message| and |handler| and records the message as a
    backref of each of its field types."""
    self.handlers[message.name] = DomatoBuilder.Entry(message, handler)
    for field in message.fields:
      if field.type.name not in self.backrefs:
        self.backrefs[field.type.name] = []
      self.backrefs[field.type.name].append(message.name)

  def _int_handler(self, part,
                   field_name: str) -> typing.Tuple[str, CppHandlerCallExpr]:
    """Handles integer tags.

    Optional `min`/`max` attributes are forwarded as extra handler arguments.
    When only `max` is given, the C++ type's minimum is passed as the lower
    bound.
    """
    proto_type = DOMATO_TO_PROTO_BUILT_IN[part['tagname']]
    handler = DOMATO_TO_CPP_HANDLERS[part['tagname']]
    extra_args = []
    if 'min' in part:
      extra_args.append(CppTxtExpression(part['min']))
    if 'max' in part:
      if not extra_args:
        cpp_type = DOMATO_INT_TYPE_TO_CPP_INT_TYPE[part['tagname']]
        extra_args.append(
            CppTxtExpression(f'std::numeric_limits<{cpp_type}>::min()'))
      extra_args.append(CppTxtExpression(part['max']))
    contents = CppHandlerCallExpr(handler=handler,
                                  field_name=field_name,
                                  extra_args=extra_args)
    return proto_type, contents

  def _lines_handler(self, part,
                     field_name: str) -> typing.Tuple[str, CppHandlerCallExpr]:
    """Handles `<lines>` tags.

    `<lines count=N>` maps to a dedicated `lines_N` message, registered on
    first use.
    """
    handler_name = 'lines'
    if 'count' in part:
      # Use the same normalized number for the type name and the registered
      # message so that both always agree.
      count = int(part['count'])
      handler_name = f'{handler_name}_{count}'
      self.maybe_add_lines_handler(count)
    proto_type = handler_name
    contents = CppHandlerCallExpr(
        handler=f'{CPP_HANDLER_PREFIX}{handler_name}', field_name=field_name)
    return proto_type, contents

  def _default_handler(
      self, part, field_name: str) -> typing.Tuple[str, CppHandlerCallExpr]:
    """Handles built-in non-integer tags with their fixed C++ handler."""
    proto_type = DOMATO_TO_PROTO_BUILT_IN[part['tagname']]
    handler = DOMATO_TO_CPP_HANDLERS[part['tagname']]
    contents = CppHandlerCallExpr(handler=handler, field_name=field_name)
    return proto_type, contents

  def _parse_rule(self, creator_name, rules):
    """Models each rule of |creator_name| as a message plus its C++ handler.

    Text parts and constant types become string literals. Built-in and
    creator tags become fields handled by a handler call. In code rules, a
    part declaring a new variable instead becomes an `optional int32` field
    named 'old', inserted first. Every rule message is registered with _add().

    Args:
        creator_name: the name of the creator.
        rules: the creator's rules, in grammar order.

    Returns:
        the rule messages, in rule order.

    Raises:
        Exception: on `<call>` tags, or on rules declaring more than one new
            variable.
    """
    messages = []
    for rule_id, rule in enumerate(rules, start=1):
      rule_msg_field_name = f'{to_proto_field_name(creator_name)}_{rule_id}'
      proto_fields = []
      cpp_contents = []
      ret_vars = 0
      for part_id, part in enumerate(rule['parts'], start=1):
        field_name = f'{rule_msg_field_name}_{part_id}'
        proto_type = None
        if rule['type'] == 'code' and 'new' in part:
          proto_fields.insert(
              0,
              ProtoField(type=ProtoType('optional int32'),
                         name='old',
                         proto_id=part_id))
          ret_vars += 1
          continue
        if part['type'] == 'text':
          contents = CppStringExpr(part['text'])
        elif part['tagname'] == 'import':
          # The current domato project is currently not handling that either in
          # its built-in rules, and I do not plan on using the feature with
          # newly written rules, as I think this directive has a lot of
          # constraints with not much added value.
          continue
        elif part['tagname'] == 'call':
          raise Exception(
              'DomatoLPM does not implement <call> and <import> tags.')
        elif part['tagname'] in self.grammar._constant_types.keys():
          contents = CppStringExpr(
              self.grammar._constant_types[part['tagname']])
        elif part['tagname'] in self._built_in_types_parser:
          handler = self._built_in_types_parser[part['tagname']]
          proto_type, contents = handler(part, field_name)
        elif part['type'] == 'tag':
          proto_type = to_proto_type(part['tagname'])
          contents = CppHandlerCallExpr(
              handler=f'{CPP_HANDLER_PREFIX}{proto_type}',
              field_name=field_name)
        if proto_type:
          proto_fields.append(
              ProtoField(type=ProtoType(name=proto_type),
                         name=field_name,
                         proto_id=part_id))
        cpp_contents.append(contents)

      if ret_vars > 1:
        raise Exception('Not implemented.')

      creator = None
      if rule['type'] == 'code' and ret_vars > 0:
        creator = {'var_type': creator_name, 'var_prefix': 'var'}
      proto_type = to_proto_type(creator_name)
      rule_msg = ProtoMessage(name=f'{proto_type}_{rule_id}',
                              fields=proto_fields)
      rule_func = CppProtoMessageFunctionHandler(name=f'{proto_type}_{rule_id}',
                                                 exprs=cpp_contents,
                                                 creator=creator)

      self._add(rule_msg, rule_func)
      messages.append(rule_msg)
    return messages

  def _remove(self, name: str):
    """Unregisters message |name| together with the backrefs it owns."""
    assert name in self.handlers
    for field in self.handlers[name].msg.fields:
      if field.type.name in self.backrefs:
        self.backrefs[field.type.name].remove(name)
    if name in self.backrefs:
      self.backrefs.pop(name)
    self.handlers.pop(name)

  def _update(self, name: str):
    """Records message |name| as a backref of each of its field types."""
    assert name in self.handlers
    for field in self.handlers[name].msg.fields:
      if field.type.name not in self.backrefs:
        self.backrefs[field.type.name] = []
      self.backrefs[field.type.name].append(name)

  def _count_backref(self, proto_name: str) -> int:
    """Counts the number of backreferences a given proto message has.

    Args:
        proto_name: the proto message name.

    Returns:
        the number of backreferences.
    """
    return len(self.backrefs[proto_name])

  def _merge_proto_messages(self) -> bool:
    """Inlines plain messages into the non-oneof messages referencing them.

    This tremendously reduces the number of protobuf messages that will be
    generated. Oneofs, messages creating a variable, and root nodes are never
    inlined.

    Returns:
        whether a change was made.
    """
    to_merge = collections.defaultdict(set)
    for name in self.handlers:
      msg = self.handlers[name].msg
      func = self.handlers[name].func
      if msg.is_one_of() or not func.is_message_handler() or func.creates_new(
      ) or self._is_root_node(name):
        continue
      if name not in self.backrefs:
        continue
      for elt in self.backrefs[name]:
        if elt == name or elt not in self.handlers:
          continue
        if self.handlers[elt].msg.is_one_of():
          continue
        to_merge[elt].add(name)

    for parent, children in to_merge.items():
      msg = self.handlers[parent].msg
      fct = self.handlers[parent].func
      for child in children:
        new_contents = []
        for expr in fct.exprs:
          if isinstance(expr, CppStringExpr):
            new_contents.append(expr)
            continue
          assert isinstance(expr, CppHandlerCallExpr)
          field: ProtoField = next(
              (f for f in msg.fields if f.type.name == child), None)
          if not field or not expr.field_name == field.name:
            new_contents.append(expr)
            continue
          self.backrefs[field.type.name].remove(msg.name)
          idx = msg.fields.index(field)
          field_msg = self.handlers[child].msg
          field_fct = self.handlers[child].func

          # The following deepcopy is required because we might change the
          # child's messages fields at some point, and we don't want those
          # changes to affect this current's message fields.
          fields_copy = copy.deepcopy(field_msg.fields)
          msg.fields = msg.fields[:idx] + fields_copy + msg.fields[idx + 1:]
          new_contents += copy.deepcopy(field_fct.exprs)
          for f in field_msg.fields:
            self.backrefs[f.type.name].append(msg.name)
        fct.exprs = new_contents
    return len(to_merge) > 0

  def _message_renamer(self):
    """Renames ProtoMessage fields that might have been merged. This ensures
    proto field naming remains consistent with the current rule being
    generated.

    Fields are renumbered from 1 and renamed `<message>_<id>`. Handler calls
    are matched to fields positionally. For messages creating a variable, the
    leading 'old' field keeps its name and calls start at index 2.
    """
    for entry in self.handlers.values():
      if entry.msg.is_one_of() or entry.func.is_string_table_handler():
        continue
      for proto_id, field in enumerate(entry.msg.fields, start=1):
        field.proto_id = proto_id
        if entry.func.creates_new() and field.name == 'old':
          continue
        field.name = to_proto_field_name(f'{entry.msg.name}_{proto_id}')
      index = 2 if entry.func.creates_new() else 1
      new_contents = []
      for expr in entry.func.exprs:
        if not isinstance(expr, CppHandlerCallExpr):
          new_contents.append(expr)
          continue
        new_contents.append(
            CppHandlerCallExpr(expr.handler,
                               to_proto_field_name(f'{entry.msg.name}_{index}'),
                               expr.extra_args))
        index += 1
      entry.func.exprs = new_contents

  def _oneof_message_renamer(self):
    """Renames OneOfProtoMessage fields that might have been merged. This
    ensures proto field naming remains consistent with the current rule being
    generated. The case of each field is re-keyed and its handler calls are
    rebuilt for the new name.
    """
    for entry in self.handlers.values():
      if not entry.msg.is_one_of():
        continue
      cases = {}
      for proto_id, field in enumerate(entry.msg.fields, start=1):
        field.proto_id = proto_id
        exprs = entry.func.cases.pop(field.name)
        field.name = to_proto_field_name(f'{entry.msg.name}_{proto_id}')
        new_contents = []
        for expr in exprs:
          if not isinstance(expr, CppHandlerCallExpr):
            new_contents.append(expr)
            continue
          new_contents.append(
              CppHandlerCallExpr(expr.handler, field.name, expr.extra_args))
        cases[field.name] = new_contents
      entry.func.cases = cases

  def _merge_multistrings_oneofs(self) -> bool:
    """Merges multiple strings into a string table function.

    A oneof whose every case is a field-less, non-oneof message emitting a
    single expression becomes a message with one `uint32 val` field, handled
    by a CppStringTableHandler over those strings.

    Returns:
        whether a change was made.
    """

    def is_single_string_message(type_name: str) -> bool:
      if type_name not in self.handlers:
        return False
      entry = self.handlers[type_name]
      return (len(entry.msg.fields) == 0 and not entry.msg.is_one_of()
              and len(entry.func.exprs) == 1)

    has_made_changes = False
    for name in list(self.handlers.keys()):
      msg = self.handlers[name].msg
      if not msg.is_one_of():
        continue
      if not all(is_single_string_message(f.type.name) for f in msg.fields):
        continue

      fields = [ProtoField(type=ProtoType('uint32'), name='val', proto_id=1)]
      new_msg = ProtoMessage(name=msg.name, fields=fields)
      strings = []
      for field in msg.fields:
        self.backrefs[field.type.name].remove(name)
        for expr in self.handlers[field.type.name].func.exprs:
          assert isinstance(expr, CppStringExpr)
          strings.append(expr)
      new_func = CppStringTableHandler(name=msg.name,
                                       var_name='val',
                                       strings=strings)
      self.handlers[name] = DomatoBuilder.Entry(new_msg, new_func)
      self._update(name)
      has_made_changes = True
    return has_made_changes

  def _oneofs_reorderer(self):
    """Reorders the OneOfProtoMessage so that the last element can be extracted
    out of the protobuf oneof's field in order to always have a correct
    path to be generated. This requires having at least one terminal path in
    the grammar.
    """
    _terminal_messages = set()
    _being_visited = set()

    def recursive_terminal_marker(name: str):
      # Returns whether |name| can terminate. Unknown types (built-ins) and
      # field-less messages are terminal; a cycle currently being visited is
      # not.
      if name in _terminal_messages or name not in self.handlers:
        return True
      if name in _being_visited:
        return False
      _being_visited.add(name)
      msg = self.handlers[name].msg
      func = self.handlers[name].func
      if len(msg.fields) == 0:
        _terminal_messages.add(name)
        _being_visited.remove(name)
        return True
      if msg.is_one_of():
        f = next(
            (f for f in msg.fields if recursive_terminal_marker(f.type.name)),
            None)
        if not f:
          #FIXME: for testing purpose only, we're not hard-failing on this.
          _being_visited.remove(name)
          return False
        # Move the terminal case last, keeping cases in sync with fields.
        msg.fields.remove(f)
        msg.fields.append(f)
        func.cases[f.name] = func.cases.pop(f.name)
        _terminal_messages.add(name)
        _being_visited.remove(name)
        return True
      res = all(recursive_terminal_marker(f.type.name) for f in msg.fields)
      #FIXME: for testing purpose only, we're not hard-failing on this.
      _being_visited.remove(name)
      return res

    for name in self.handlers:
      recursive_terminal_marker(name)

  def _merge_oneofs(self) -> bool:
    """Inlines single-field messages into the oneofs referencing them.

    A oneof case whose message is a plain message handler (not a oneof, not
    creating a variable) with exactly one field is replaced by that field.
    The case then uses a copy of that message's expressions.

    Returns:
        whether a change was made.
    """
    has_made_changes = False
    for name in list(self.handlers.keys()):
      msg = self.handlers[name].msg
      func = self.handlers[name].func
      if not msg.is_one_of():
        continue

      for field in msg.fields:
        if field.type.name not in self.handlers:
          continue
        field_msg = self.handlers[field.type.name].msg
        field_func = self.handlers[field.type.name].func
        if field_msg.is_one_of() or len(
            field_msg.fields) != 1 or not field_func.is_message_handler(
            ) or field_func.creates_new():
          continue
        func.cases.pop(field.name)
        field.name = field_msg.fields[0].name
        field.type = field_msg.fields[0].type
        # Keep case names unique; _oneof_message_renamer() renames them later.
        while field.name in func.cases:
          field.name += '_1'
        func.cases[field.name] = copy.deepcopy(field_func.exprs)
        self.backrefs[field_msg.name].remove(name)
        self.backrefs[field.type.name].append(name)
        has_made_changes = True
    return has_made_changes

  def _merge_unary_oneofs(self) -> bool:
    """Transforms OneOfProtoMessage messages containing only one field into a
    ProtoMessage containing the fields of the contained message. E.g.:
        message B {
          int field1 = 1;
          Whatever field2 = 2;
        }
        message A {
          oneof field {
            B b = 1;
          }
        }
        Into:
        message A {
          int field1 = 1;
          Whatever field2 = 2;
        }

    Returns:
        whether a change was made.
    """
    has_made_changes = False
    for name in list(self.handlers.keys()):
      msg = self.handlers[name].msg

      if not msg.is_one_of() or len(msg.fields) != 1:
        continue

      child_name = msg.fields[0].type.name
      # Only generated messages can be inlined. A built-in type (which
      # _merge_oneofs() may put in a oneof) has no handler entry.
      if child_name not in self.handlers:
        continue

      # The message is a unary oneof. Let's make sure its only child doesn't
      # have other backrefs.
      if self._count_backref(child_name) > 1:
        continue

      # The only backref should really only be us. If not we screwed up
      # somewhere else.
      assert name in self.backrefs[child_name]
      child = self.handlers[child_name]
      # Only plain message handlers carry the `creator` reused below.
      if child.msg.is_one_of() or not child.func.is_message_handler():
        continue

      self._remove(child_name)
      new_msg = ProtoMessage(name=msg.name, fields=child.msg.fields)
      new_func = CppProtoMessageFunctionHandler(name=new_msg.name,
                                                exprs=child.func.exprs,
                                                creator=child.func.creator)
      self.handlers[name] = DomatoBuilder.Entry(new_msg, new_func)
      self._update(name)
      has_made_changes = True
    return has_made_changes

  def _merge_strings(self) -> bool:
    """Merges following CppString, e.g.
    [ CppString("<first>"), CppString("<second>")]
    Into:
    [ CppString("<first><second>")]

    Returns:
        whether a change was made.
    """
    has_made_changes = False
    for name in self.handlers:
      func: CppFunctionHandler = self.handlers[name].func
      if not func.is_message_handler() or len(func.exprs) <= 1:
        continue

      exprs = []
      prev = func.exprs[0]
      for cur in func.exprs[1:]:
        if isinstance(prev, CppStringExpr) and isinstance(cur, CppStringExpr):
          cur = CppStringExpr(prev.content + cur.content)
          has_made_changes = True
        else:
          exprs.append(prev)
        prev = cur
      exprs.append(prev)
      func.exprs = exprs
    return has_made_changes

  def _is_root_node(self, name: str) -> bool:
    """Whether message |name| must be kept as a root.

    With a line-based root, every name matching `line`, `lines`, `line_N` or
    `lines_N` is a root. Otherwise only the configured root is.
    """
    # If there is no existing root, we set it to `lines`, since this will
    # be picked as the default root.
    if 'line' not in self.root:
      return self.root == name
    return re.match(r'^line(s)?(_[0-9]*)?$', name) is not None

  def _remove_unlinked_nodes(self) -> bool:
    """Removes proto messages that are neither part of the root definition nor
    referenced by any other messages. This can happen during other optimization
    functions.

    Returns:
        whether a change was made.
    """
    to_remove = set()
    for name in self.handlers:
      if name not in self.backrefs or len(self.backrefs[name]) == 0:
        if not self._is_root_node(name):
          to_remove.add(name)
    local_root = 'line' if self.should_generate_one_line_handler(
    ) else self.root

    # Collect the names of every message reachable from the root. An explicit
    # stack is used instead of recursion so that deeply nested grammars cannot
    # exceed Python's recursion limit.
    seen = set()
    pending = [self.handlers[local_root].msg]
    while pending:
      msg = pending.pop()
      if msg.name in seen:
        continue
      seen.add(msg.name)
      pending.extend(self.handlers[field.type.name].msg
                     for field in msg.fields
                     if field.type.name in self.handlers)

    not_seen = set(self.handlers.keys()) - seen
    to_remove.update(name for name in not_seen if not self._is_root_node(name))
    for name in to_remove:
      self._remove(name)
    return len(to_remove) > 0
914
915
def _render_internal(template: jinja2.Template,
                     context: typing.Dict[str, typing.Any], out_f: str):
  """Renders |template| with |context| and writes the text to |out_f|.

  The output is written through action_helpers.atomic_output (presumably so
  that |out_f| is never left half-written; see that helper to confirm).
  Rendering happens inside the output context.
  """
  with action_helpers.atomic_output(out_f, mode='w') as out_file:
    rendered = template.render(context)
    out_file.write(rendered)
920
921
def _render_proto_internal(
    template: jinja2.Template, out_f: str,
    proto_messages: typing.List[typing.Union[ProtoMessage, OneOfProtoMessage]],
    should_generate_repeated_lines: bool, proto_ns: str,
    imports: typing.List[str]):
  """Renders a .proto file from |template| into |out_f|.

  Args:
    template: the proto template to render.
    out_f: path of the file to write.
    proto_messages: messages to emit. They are handed to the template as two
      lists: `messages` (non one-of) and `oneofmessages` (one-of).
    should_generate_repeated_lines: passed as `generate_repeated_lines`.
    proto_ns: dotted proto namespace, passed as `proto_ns`.
    imports: files to import, passed as `imports`.
  """
  plain_messages = []
  oneof_messages = []
  for message in proto_messages:
    if message.is_one_of():
      oneof_messages.append(message)
    else:
      plain_messages.append(message)

  context = {
      'messages': plain_messages,
      'oneofmessages': oneof_messages,
      'generate_repeated_lines': should_generate_repeated_lines,
      'proto_ns': proto_ns,
      'imports': imports,
  }
  _render_internal(template, context, out_f=out_f)
935
936
def render_proto(environment: jinja2.Environment, generated_dir: str,
                 out_f: str, name: str, builder: DomatoBuilder):
  """Writes the root and sub .proto files for the grammar |name|.

  `<out_f>.proto` receives the root messages and imports the sub file as
  `<generated_dir>/<basename of out_f>_sub.proto`. `<out_f>_sub.proto`
  receives every other message. Both files use the namespace
  `<BASE_PROTO_NS>.<name>`.
  """
  template = environment.get_template('domatolpm.proto.tmpl')
  root_msgs, other_msgs = builder.get_protos()
  proto_ns = f'{BASE_PROTO_NS}.{name}'
  main_path = f'{out_f}.proto'
  sub_path = f'{out_f}_sub.proto'
  sub_import = str(
      pathlib.PurePosixPath(generated_dir) /
      pathlib.PurePosixPath(sub_path).name)
  _render_proto_internal(template, main_path, root_msgs,
                         builder.should_generate_repeated_lines(), proto_ns,
                         [sub_import])
  _render_proto_internal(template, sub_path, other_msgs, False, proto_ns, [])
950
951
def render_cpp(environment: jinja2.Environment, out_f: str, name: str,
               builder: DomatoBuilder):
  """Writes `<out_f>.cc` and `<out_f>.h` for the grammar |name|.

  Both files are rendered from the same context. That context filters the
  builder's C++ functions into message handlers, one-of handlers and
  string-table handlers, and adds the root function and line settings.
  """
  all_funcs = builder.all_cpp_functions()
  message_funcs = [fn for fn in all_funcs if fn.is_message_handler()]
  oneof_funcs = [fn for fn in all_funcs if fn.is_oneof_handler()]
  string_table_funcs = [fn for fn in all_funcs if fn.is_string_table_handler()]
  root_func = builder.get_roots()[1]

  context = dict(
      basename=os.path.basename(out_f),
      functions=message_funcs,
      oneoffunctions=oneof_funcs,
      stfunctions=string_table_funcs,
      root=root_func,
      generate_repeated_lines=builder.should_generate_repeated_lines(),
      generate_one_line_handler=builder.should_generate_one_line_handler(),
      line_prefix=builder.get_line_prefix(),
      line_suffix=builder.get_line_suffix(),
      proto_ns=to_cpp_ns(f'{BASE_PROTO_NS}.{name}'),
      cpp_ns=f'domatolpm::{name}',
  )
  for template_name, out_path in (('domatolpm.cc.tmpl', f'{out_f}.cc'),
                                  ('domatolpm.h.tmpl', f'{out_f}.h')):
    _render_internal(environment.get_template(template_name), context,
                     out_path)
977
978
def main():
  """Parses a Domato grammar and writes its DomatoLPM C++ and proto files."""
  parser = argparse.ArgumentParser(
      description=
      'Generate the necessary files for DomatoLPM to function properly.')
  # Every flag is mandatory; only the flag names and help text differ.
  for short_flag, long_flag, help_text in (
      ('-p', '--path', 'The path to a Domato grammar file.'),
      ('-n', '--name', 'The name of this grammar.'),
      ('-f', '--file-format',
       'The path prefix to which the files should be generated.'),
      ('-d', '--generated-dir', 'The path to the target gen directory.'),
  ):
    parser.add_argument(short_flag, long_flag, required=True, help=help_text)
  args = parser.parse_args()

  parsed_grammar = grammar.Grammar()
  parsed_grammar.parse_from_file(filename=args.path)

  templates_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                'templates')
  environment = jinja2.Environment(
      loader=jinja2.FileSystemLoader(templates_path))

  builder = DomatoBuilder(parsed_grammar)
  builder.parse_grammar()
  builder.simplify()

  render_cpp(environment, args.file_format, args.name, builder)
  render_proto(environment, args.generated_dir, args.file_format, args.name,
               builder)


if __name__ == '__main__':
  main()
1018