#!/usr/bin/env python3
#
# Copyright 2024 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Generates the protobuf and C++ files DomatoLPM needs from a Domato grammar.
"""

import argparse
import collections
import copy
import os
import pathlib
import sys
import typing
import re
import dataclasses


def _GetDirAbove(dirname: str):
  """Returns the directory "above" this file containing |dirname| (which must
  also be "above" this file).

  Returns None when no path component above this file is named |dirname|.
  """
  path = os.path.abspath(__file__)
  while True:
    path, tail = os.path.split(path)
    if not tail:
      return None
    if tail == dirname:
      return path


SOURCE_DIR = _GetDirAbove('testing')

sys.path.insert(1, os.path.join(SOURCE_DIR, 'third_party'))
sys.path.insert(1, os.path.join(SOURCE_DIR, 'third_party/domato/src'))
sys.path.append(os.path.join(SOURCE_DIR, 'build'))

import action_helpers
import jinja2
import grammar

# TODO(crbug.com/361369290): Remove this disable once DomatoLPM development is
# finished and upstream changes can be made to expose the relevant protected
# fields.
# pylint: disable=protected-access


def to_snake_case(name):
  """Converts a CamelCase |name| to snake_case.

  Runs of capitals followed by a capitalized word are split first (e.g.
  'HTMLElement' -> 'HTML_Element'), then every lower/digit-to-upper boundary,
  and finally the result is lowercased.
  """
  name = re.sub(r'([A-Z]{2,})([A-Z][a-z])', r'\1_\2', name)
  # The default `count` already replaces every match; passing it positionally
  # is deprecated as of Python 3.13.
  return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', name).lower()


# Maps Domato integer tag names to the C++ type used when a default lower
# bound has to be emitted (see DomatoBuilder._int_handler).
DOMATO_INT_TYPE_TO_CPP_INT_TYPE = {
    'int': 'int',
    'int32': 'int32_t',
    'uint32': 'uint32_t',
    'int8': 'int8_t',
    'uint8': 'uint8_t',
    'int16': 'int16_t',
    'uint16': 'uint16_t',
    # Signed 64-bit must map to the signed C++ type, otherwise the default
    # minimum (std::numeric_limits<T>::min()) would be 0.
    'int64': 'int64_t',
    'uint64': 'uint64_t',
}

# Maps Domato built-in tag names to the protobuf field type that carries them.
# Protobuf has no 8- or 16-bit scalar types, so those are widened to 32 bits.
DOMATO_TO_PROTO_BUILT_IN = {
    'int': 'int32',
    'int32': 'int32',
    'uint32': 'uint32',
    'int8': 'int32',
    'uint8': 'uint32',
    'int16': 'int32',
    'uint16': 'uint32',
    'int64': 'int64',
    'uint64': 'uint64',
    'float': 'float',
    'double': 'double',
    'char': 'int32',
    'string': 'string',
    'htmlsafestring': 'string',
    'hex': 'int32',
    'lines': 'repeated lines',
}

# Maps Domato built-in tag names to the name of the C++ handler function that
# is called for them in the generated code.
DOMATO_TO_CPP_HANDLERS = {
    'int': 'handle_int_conversion',
    'int32': 'handle_int_conversion',
    'uint32': 'handle_int_conversion',
    'int8': 'handle_int_conversion',
    'uint8': 'handle_int_conversion',
    'int16': 'handle_int_conversion',
    'uint16': 'handle_int_conversion',
    'int64': 'handle_int_conversion',
    'uint64': 'handle_int_conversion',
    'float': 'handle_float',
    'double': 'handle_double',
    'char': 'handle_char',
    'string': 'handle_string',
    'htmlsafestring': 'handle_string',
    'hex': 'handle_hex',
}

# Translation table used to escape a Python string into a C string literal.
_C_STR_TRANS = str.maketrans({
    '\n': '\\n',
    '\r': '\\r',
    '\t': '\\t',
    '\"': '\\\"',
    '\\': '\\\\'
})

BASE_PROTO_NS = 'domatolpm.generated'

CPP_HANDLER_PREFIX = 'handle_'

# Names that get a `_proto` suffix when used as proto identifiers. Presumably
# these clash with keywords in the generated code -- TODO confirm.
_CONFLICTING_NAMES = frozenset(
    ['short', 'class', 'bool', 'boolean', 'long', 'void'])

# Placeholder separating the prefix and the suffix of a grammar line guard.
# NOTE(review): taken from Domato's `!lineguard` convention; confirm against
# third_party/domato.
_LINE_GUARD_PLACEHOLDER = '<line>'


def to_cpp_ns(proto_ns: str) -> str:
  """Converts a dotted proto namespace into a `::`-separated C++ namespace."""
  return proto_ns.replace('.', '::')


def to_proto_field_name(name: str) -> str:
  """Converts a creator or rule name to a proto field name.

  This tries to respect the protobuf naming convention that field names should
  be snake case.

  Args:
    name: the name of the creator or the rule.

  Returns:
    the proto field name to use.
  """
  res = to_snake_case(name.replace('-', '_'))
  if res in _CONFLICTING_NAMES:
    res += '_proto'
  return res


def to_proto_type(creator_name: str) -> str:
  """Converts a creator name to a proto type.

  This is deliberately very simple so that we avoid naming conflicts.

  Args:
    creator_name: the name of the creator.

  Returns:
    the name of the proto type.
  """
  res = creator_name.replace('-', '_')
  if res in _CONFLICTING_NAMES:
    res += '_proto'
  return res


def c_escape(v: str) -> str:
  """Escapes newlines, tabs, quotes and backslashes for a C string literal."""
  return v.translate(_C_STR_TRANS)


@dataclasses.dataclass
class ProtoType:
  """Represents a Proto type."""
  name: str

  def is_one_of(self) -> bool:
    return False


@dataclasses.dataclass
class ProtoField:
  """Represents a proto message field."""
  type: ProtoType
  name: str
  proto_id: int


@dataclasses.dataclass
class ProtoMessage(ProtoType):
  """Represents a Proto message."""
  fields: typing.List[ProtoField]


@dataclasses.dataclass
class OneOfProtoMessage(ProtoMessage):
  """Represents a Proto message with a oneof field."""
  oneofname: str

  def is_one_of(self) -> bool:
    return True


class CppExpression:
  """Base class of the C++ expressions emitted into the generated code."""

  def repr(self):
    """Returns the C++ source text of this expression."""
    raise NotImplementedError('Not implemented.')


@dataclasses.dataclass
class CppTxtExpression(CppExpression):
  """Represents a Raw text expression."""
  content: str

  def repr(self):
    return self.content


@dataclasses.dataclass
class CppCallExpr(CppExpression):
  """Represents a CallExpr."""
  fct_name: str
  args: typing.List[CppExpression]
  ns: str = ''

  def repr(self):
    arg_s = ', '.join([a.repr() for a in self.args])
    return f'{self.ns}{self.fct_name}({arg_s})'


class CppHandlerCallExpr(CppCallExpr):
  """Represents a call to a handler: `handler(ctx, arg.<field>(), extra...)`.
  """

  def __init__(self,
               handler: str,
               field_name: str,
               extra_args: typing.Optional[typing.List[CppExpression]] = None):
    args = [CppTxtExpression('ctx'), CppTxtExpression(f'arg.{field_name}()')]
    if extra_args:
      args += extra_args
    super().__init__(fct_name=handler, args=args)
    # Kept so that the call can be rebuilt when the field gets renamed.
    self.handler = handler
    self.field_name = field_name
    self.extra_args = extra_args


@dataclasses.dataclass
class CppStringExpr(CppExpression):
  """Represents a C++ literal string.
  """
  content: str

  def repr(self):
    return f'\"{c_escape(self.content)}\"'


@dataclasses.dataclass
class CppFunctionHandler:
  """Represents a C++ function.
  """
  name: str
  exprs: typing.List[CppExpression]

  def is_oneof_handler(self) -> bool:
    return False

  def is_string_table_handler(self) -> bool:
    return False

  def is_message_handler(self) -> bool:
    return False


class CppStringTableHandler(CppFunctionHandler):
  """Represents a C++ function that implements a string table and returns one
  of the represented strings.
  """

  def __init__(self, name: str, var_name: str,
               strings: typing.List[CppStringExpr]):
    super().__init__(name=f'{CPP_HANDLER_PREFIX}{name}', exprs=[])
    self.proto_type = f'{name}& arg'
    self.strings = strings
    self.var_name = var_name

  def is_string_table_handler(self) -> bool:
    return True


class CppProtoMessageFunctionHandler(CppFunctionHandler):
  """Represents a C++ function that handles a ProtoMessage.
  """

  def __init__(self,
               name: str,
               exprs: typing.List[CppExpression],
               creator: typing.Optional[typing.Dict[str, str]] = None):
    super().__init__(name=f'{CPP_HANDLER_PREFIX}{name}', exprs=exprs)
    self.proto_type = f'{name}& arg'
    self.creator = creator

  def creates_new(self):
    """Returns whether this handler has creator information attached."""
    return self.creator is not None

  def is_message_handler(self) -> bool:
    return True


class CppOneOfMessageFunctionHandler(CppFunctionHandler):
  """Represents a C++ function that handles a OneOfProtoMessage.
  """

  def __init__(self, name: str, switch_name: str,
               cases: typing.Dict[str, typing.List[CppExpression]]):
    super().__init__(name=f'{CPP_HANDLER_PREFIX}{name}', exprs=[])
    self.proto_type = f'{name}& arg'
    self.switch_name = switch_name
    self.cases = cases

  def all_except_last(self):
    """Returns every case but the last one, in insertion order."""
    a = list(self.cases.keys())[:-1]
    return {e: self.cases[e] for e in a}

  def last(self):
    """Returns the expressions of the last inserted case."""
    a = list(self.cases.keys())[-1]
    return self.cases[a]

  def is_oneof_handler(self) -> bool:
    return True


class DomatoBuilder:
  """DomatoBuilder is the class that takes a Domato grammar, and models it
  into a protobuf representation and its corresponding C++ parsing code.
  """

  @dataclasses.dataclass
  class Entry:
    """A proto message together with the C++ function that handles it."""
    msg: ProtoMessage
    func: CppFunctionHandler

  def __init__(self, g: grammar.Grammar):
    # Message name -> (message, handler function).
    self.handlers: typing.Dict[str, DomatoBuilder.Entry] = {}
    # Type name -> names of the messages having a field of that type. A name
    # appears once per referencing field.
    self.backrefs: typing.Dict[str, typing.List[str]] = {}
    self.grammar = g
    grammar_root = self.grammar._root
    if grammar_root and grammar_root != 'root':
      self.root = grammar_root
    else:
      self.root = 'lines'
    if grammar_root and grammar_root == 'root':
      rules = self.grammar._creators[grammar_root]
      # multiple roots doesn't make sense, so we only consider the last defined
      # one.
      rule = rules[-1]
      for part in rule['parts']:
        if (part['type'] == 'tag' and part['tagname'] == 'lines'
            and 'count' in part):
          self.root = f'lines_{part["count"]}'
          break
    self._built_in_types_parser = {
        'int': self._int_handler,
        'int32': self._int_handler,
        'uint32': self._int_handler,
        'int8': self._int_handler,
        'uint8': self._int_handler,
        'int16': self._int_handler,
        'uint16': self._int_handler,
        'int64': self._int_handler,
        'uint64': self._int_handler,
        'float': self._default_handler,
        'double': self._default_handler,
        'char': self._default_handler,
        'string': self._default_handler,
        'htmlsafestring': self._default_handler,
        'hex': self._default_handler,
        'lines': self._lines_handler,
    }

  def parse_grammar(self):
    """Registers one oneof message (and handler) per grammar creator.

    Each rule of a creator becomes its own message (see _parse_rule) and the
    creator's oneof message gets one field per rule.
    """
    for creator, rules in self.grammar._creators.items():
      field_name = to_proto_field_name(creator)
      type_name = to_proto_type(creator)
      messages = self._parse_rule(creator, rules)
      proto_fields: typing.List[ProtoField] = []
      for proto_id, msg in enumerate(messages, start=1):
        proto_fields.append(
            ProtoField(type=ProtoType(name=msg.name),
                       name=f'{field_name}_{proto_id}',
                       proto_id=proto_id))
      msg = OneOfProtoMessage(name=type_name,
                              oneofname='oneoffield',
                              fields=proto_fields)
      cases = {
          f.name: [
              CppHandlerCallExpr(handler=f'{CPP_HANDLER_PREFIX}{f.type.name}',
                                 field_name=f.name)
          ]
          for f in proto_fields
      }
      func = CppOneOfMessageFunctionHandler(name=type_name,
                                            switch_name='oneoffield',
                                            cases=cases)
      self._add(msg, func)

  def all_proto_messages(self):
    """Returns every registered proto message."""
    return [v.msg for v in self.handlers.values()]

  def all_cpp_functions(self):
    """Returns every registered C++ handler function."""
    return [v.func for v in self.handlers.values()]

  def _split_line_guard(self) -> typing.Tuple[str, str]:
    """Splits the grammar line guard around its line placeholder.

    Returns:
      a (prefix, suffix) pair. Both are empty when the grammar has no line
      guard; the suffix is empty when the guard has no placeholder.
    """
    line_guard = self.grammar._line_guard
    if not line_guard:
      return '', ''
    prefix, _, suffix = line_guard.partition(_LINE_GUARD_PLACEHOLDER)
    return prefix, suffix

  def get_line_prefix(self) -> str:
    """Returns the part of the line guard preceding the line placeholder."""
    return self._split_line_guard()[0]

  def get_line_suffix(self) -> str:
    """Returns the part of the line guard following the line placeholder."""
    return self._split_line_guard()[1]

  def should_generate_repeated_lines(self):
    return self.root == 'lines'

  def should_generate_one_line_handler(self):
    return self.root.startswith('lines')

  def maybe_add_lines_handler(self, number: int) -> bool:
    """Registers a `lines_<number>` message made of |number| `line` fields.

    Returns:
      False if such a handler was already registered, True otherwise.
    """
    name = f'lines_{number}'
    if name in self.handlers:
      return False
    fields = []
    exprs = []
    for i in range(1, number + 1):
      fields.append(ProtoField(ProtoType('line'), f'line_{i}', i))
      exprs.append(CppHandlerCallExpr('handle_one_line', f'line_{i}'))
    msg = ProtoMessage(name, fields=fields)
    handler = CppProtoMessageFunctionHandler(name, exprs=exprs)
    self.handlers[name] = DomatoBuilder.Entry(msg, handler)
    return True

  def get_roots(self) -> typing.Tuple[ProtoMessage, CppFunctionHandler]:
    """Builds the `fuzzcase` message and handler wrapping the grammar root."""
    root = self.root
    root_handler = f'{CPP_HANDLER_PREFIX}{root}'
    fuzz_case = ProtoMessage(
        name='fuzzcase',
        fields=[ProtoField(type=ProtoType(name=root), name='root',
                           proto_id=1)])
    fuzz_fct = CppProtoMessageFunctionHandler(
        name='fuzzcase',
        exprs=[CppHandlerCallExpr(handler=root_handler, field_name='root')])
    return fuzz_case, fuzz_fct

  def get_protos(
      self
  ) -> typing.Tuple[typing.List[ProtoMessage], typing.List[ProtoMessage]]:
    """Returns the proto messages as a (roots, non_roots) pair."""
    if self.should_generate_one_line_handler():
      # We're handling a code grammar.
      roots = [v.msg for k, v in self.handlers.items() if k.startswith('line')]
      roots.append(self.get_roots()[0])
      non_roots = [
          v.msg for k, v in self.handlers.items() if not k.startswith('line')
      ]
      return roots, non_roots
    return [self.get_roots()[0]], self.all_proto_messages()

  def simplify(self):
    """Simplifies the proto and functions."""
    should_continue = True
    # Run the simplification passes until none of them reports a change.
    while should_continue:
      should_continue = False
      should_continue |= self._merge_unary_oneofs()
      should_continue |= self._merge_strings()
      should_continue |= self._merge_multistrings_oneofs()
      should_continue |= self._remove_unlinked_nodes()
      should_continue |= self._merge_proto_messages()
      should_continue |= self._merge_oneofs()
    self._oneofs_reorderer()
    self._oneof_message_renamer()
    self._message_renamer()

  def _add(self, message: ProtoMessage, handler: CppFunctionHandler):
    """Registers |message| with its |handler| and records its backrefs."""
    self.handlers[message.name] = DomatoBuilder.Entry(message, handler)
    for field in message.fields:
      self.backrefs.setdefault(field.type.name, []).append(message.name)

  def _int_handler(
      self, part,
      field_name: str) -> typing.Tuple[str, CppHandlerCallExpr]:
    """Handles an integer built-in tag.

    Returns:
      the proto type name and the C++ handler call. The call gets the `min`
      and `max` tag attributes as extra arguments when present; when only
      `max` is given, the C++ type's numeric minimum is used as lower bound.
    """
    proto_type = DOMATO_TO_PROTO_BUILT_IN[part['tagname']]
    handler = DOMATO_TO_CPP_HANDLERS[part['tagname']]
    extra_args = []
    if 'min' in part:
      extra_args.append(CppTxtExpression(part['min']))
    if 'max' in part:
      if not extra_args:
        cpp_type = DOMATO_INT_TYPE_TO_CPP_INT_TYPE[part['tagname']]
        extra_args.append(
            CppTxtExpression(f'std::numeric_limits<{cpp_type}>::min()'))
      extra_args.append(CppTxtExpression(part['max']))
    contents = CppHandlerCallExpr(handler=handler,
                                  field_name=field_name,
                                  extra_args=extra_args)
    return proto_type, contents

  def _lines_handler(
      self, part,
      field_name: str) -> typing.Tuple[str, CppHandlerCallExpr]:
    """Handles a `lines` tag, registering a `lines_<count>` message if the tag
    has a `count` attribute."""
    handler_name = 'lines'
    if 'count' in part:
      count = part['count']
      handler_name = f'{handler_name}_{count}'
      self.maybe_add_lines_handler(int(part['count']))
    proto_type = handler_name
    contents = CppHandlerCallExpr(handler=f'handle_{handler_name}',
                                  field_name=field_name)
    return proto_type, contents

  def _default_handler(
      self, part,
      field_name: str) -> typing.Tuple[str, CppHandlerCallExpr]:
    """Handles a built-in tag that takes no extra handler arguments."""
    proto_type = DOMATO_TO_PROTO_BUILT_IN[part['tagname']]
    handler = DOMATO_TO_CPP_HANDLERS[part['tagname']]
    contents = CppHandlerCallExpr(handler=handler, field_name=field_name)
    return proto_type, contents

  def _parse_rule(self, creator_name, rules):
    """Registers one message and handler per rule of a creator.

    Args:
      creator_name: the name of the creator owning |rules|.
      rules: the grammar rules of that creator.

    Returns:
      the list of created messages, in rule order.

    Raises:
      NotImplementedError: on a `call` tag, or when a rule has more than one
        part carrying a `new` attribute.
      ValueError: on a part that is neither text nor a tag.
    """
    messages = []
    for rule_id, rule in enumerate(rules, start=1):
      rule_msg_field_name = f'{to_proto_field_name(creator_name)}_{rule_id}'
      proto_fields = []
      cpp_contents = []
      ret_vars = 0
      for part_id, part in enumerate(rule['parts'], start=1):
        field_name = f'{rule_msg_field_name}_{part_id}'
        proto_type = None
        if rule['type'] == 'code' and 'new' in part:
          proto_fields.insert(
              0,
              ProtoField(type=ProtoType('optional int32'),
                         name='old',
                         proto_id=part_id))
          ret_vars += 1
          continue
        if part['type'] == 'text':
          contents = CppStringExpr(part['text'])
        elif part['tagname'] == 'import':
          # The current domato project is currently not handling that either in
          # its built-in rules, and I do not plan on using the feature with
          # newly written rules, as I think this directive has a lot of
          # constraints with not much added value.
          continue
        elif part['tagname'] == 'call':
          raise NotImplementedError(
              'DomatoLPM does not implement <import> and <call> tags.')
        elif part['tagname'] in self.grammar._constant_types:
          contents = CppStringExpr(
              self.grammar._constant_types[part['tagname']])
        elif part['tagname'] in self._built_in_types_parser:
          handler = self._built_in_types_parser[part['tagname']]
          proto_type, contents = handler(part, field_name)
        elif part['type'] == 'tag':
          proto_type = to_proto_type(part['tagname'])
          contents = CppHandlerCallExpr(
              handler=f'{CPP_HANDLER_PREFIX}{proto_type}',
              field_name=field_name)
        else:
          # Without this, `contents` would be stale (or unbound) below.
          raise ValueError(f'Unsupported rule part type: {part["type"]!r}')
        if proto_type:
          proto_fields.append(
              ProtoField(type=ProtoType(name=proto_type),
                         name=field_name,
                         proto_id=part_id))
        cpp_contents.append(contents)

      if ret_vars > 1:
        raise NotImplementedError('Not implemented.')
      creator = None
      if rule['type'] == 'code' and ret_vars > 0:
        creator = {'var_type': creator_name, 'var_prefix': 'var'}
      proto_type = to_proto_type(creator_name)
      rule_msg = ProtoMessage(name=f'{proto_type}_{rule_id}',
                              fields=proto_fields)
      rule_func = CppProtoMessageFunctionHandler(name=f'{proto_type}_{rule_id}',
                                                 exprs=cpp_contents,
                                                 creator=creator)
      self._add(rule_msg, rule_func)
      messages.append(rule_msg)
    return messages

  def _remove(self, name: str):
    """Unregisters message |name| and drops the backrefs it contributed."""
    assert name in self.handlers
    for field in self.handlers[name].msg.fields:
      if field.type.name in self.backrefs:
        self.backrefs[field.type.name].remove(name)
    if name in self.backrefs:
      self.backrefs.pop(name)
    self.handlers.pop(name)

  def _update(self, name: str):
    """Records the backrefs contributed by the fields of message |name|."""
    assert name in self.handlers
    for field in self.handlers[name].msg.fields:
      self.backrefs.setdefault(field.type.name, []).append(name)

  def _count_backref(self, proto_name: str) -> int:
    """Counts the number of backreference a given proto message has.

    Args:
      proto_name: the proto message name.

    Returns:
      the number of backreferences.
    """
    return len(self.backrefs[proto_name])

  def _merge_proto_messages(self) -> bool:
    """Merges messages referencing other messages into the same message. This
    allows to tremendously reduce the number of protobuf messages that will be
    generated.

    Returns:
      whether at least one merge candidate was found.
    """
    # Parent message name -> names of the child messages to inline into it.
    to_merge = collections.defaultdict(set)
    for name in self.handlers:
      msg = self.handlers[name].msg
      func = self.handlers[name].func
      if (msg.is_one_of() or not func.is_message_handler()
          or func.creates_new() or self._is_root_node(name)):
        continue
      if name not in self.backrefs:
        continue
      for elt in self.backrefs[name]:
        if elt == name or elt not in self.handlers:
          continue
        if self.handlers[elt].msg.is_one_of():
          continue
        to_merge[elt].add(name)

    for parent, childs in to_merge.items():
      msg = self.handlers[parent].msg
      fct = self.handlers[parent].func
      # Sorted so that the result never depends on set iteration order.
      for child in sorted(childs):
        new_contents = []
        for expr in fct.exprs:
          if isinstance(expr, CppStringExpr):
            new_contents.append(expr)
            continue
          assert isinstance(expr, CppHandlerCallExpr)
          field: typing.Optional[ProtoField] = next(
              (f for f in msg.fields if f.type.name == child), None)
          if not field or not expr.field_name == field.name:
            new_contents.append(expr)
            continue
          self.backrefs[field.type.name].remove(msg.name)
          idx = msg.fields.index(field)
          field_msg = self.handlers[child].msg
          field_fct = self.handlers[child].func
          # The following deepcopy is required because we might change the
          # child's messages fields at some point, and we don't want those
          # changes to affect this current's message fields.
          fields_copy = copy.deepcopy(field_msg.fields)
          msg.fields = msg.fields[:idx] + fields_copy + msg.fields[idx + 1:]
          new_contents += copy.deepcopy(field_fct.exprs)
          for f in field_msg.fields:
            self.backrefs[f.type.name].append(msg.name)
        fct.exprs = new_contents
    return len(to_merge) > 0

  def _message_renamer(self):
    """Renames ProtoMessage fields that might have been merged. This ensures
    proto field naming remains consistent with the current rule being
    generated.
    """
    for entry in self.handlers.values():
      if entry.msg.is_one_of() or entry.func.is_string_table_handler():
        continue
      for proto_id, field in enumerate(entry.msg.fields, start=1):
        field.proto_id = proto_id
        if entry.func.creates_new() and field.name == 'old':
          continue
        field.name = to_proto_field_name(f'{entry.msg.name}_{proto_id}')
      # Creator handlers have the `old` field first, which no handler call
      # refers to, so their calls start at the second field.
      index = 2 if entry.func.creates_new() else 1
      new_contents = []
      for expr in entry.func.exprs:
        if not isinstance(expr, CppHandlerCallExpr):
          new_contents.append(expr)
          continue
        new_contents.append(
            CppHandlerCallExpr(expr.handler,
                               to_proto_field_name(f'{entry.msg.name}_{index}'),
                               expr.extra_args))
        index += 1
      entry.func.exprs = new_contents

  def _oneof_message_renamer(self):
    """Renames OneOfProtoMessage fields that might have been merged. This
    ensures proto field naming remains consistent with the current rule being
    generated.
    """
    for entry in self.handlers.values():
      if not entry.msg.is_one_of():
        continue
      cases = {}
      for proto_id, field in enumerate(entry.msg.fields, start=1):
        field.proto_id = proto_id
        exprs = entry.func.cases.pop(field.name)
        field.name = to_proto_field_name(f'{entry.msg.name}_{proto_id}')
        new_contents = []
        for expr in exprs:
          if not isinstance(expr, CppHandlerCallExpr):
            new_contents.append(expr)
            continue
          new_contents.append(
              CppHandlerCallExpr(expr.handler, field.name, expr.extra_args))
        cases[field.name] = new_contents
      entry.func.cases = cases

  def _merge_multistrings_oneofs(self) -> bool:
    """Merges multiple strings into a string table function.

    A oneof qualifies when each of its fields refers to a registered,
    field-less, non-oneof message whose handler has exactly one expression.

    Returns:
      whether a change was made.
    """

    def is_single_string(type_name: str) -> bool:
      if type_name not in self.handlers:
        return False
      entry = self.handlers[type_name]
      return (len(entry.msg.fields) == 0 and not entry.msg.is_one_of()
              and len(entry.func.exprs) == 1)

    has_made_changes = False
    for name in list(self.handlers.keys()):
      msg = self.handlers[name].msg
      if not msg.is_one_of():
        continue
      if not all(is_single_string(f.type.name) for f in msg.fields):
        continue
      fields = [ProtoField(type=ProtoType('uint32'), name='val', proto_id=1)]
      new_msg = ProtoMessage(name=msg.name, fields=fields)
      strings = []
      for field in msg.fields:
        self.backrefs[field.type.name].remove(name)
        for expr in self.handlers[field.type.name].func.exprs:
          assert isinstance(expr, CppStringExpr)
          strings.append(expr)
      new_func = CppStringTableHandler(name=msg.name,
                                       var_name='val',
                                       strings=strings)
      self.handlers[name] = DomatoBuilder.Entry(new_msg, new_func)
      self._update(name)
      has_made_changes = True
    return has_made_changes

  def _oneofs_reorderer(self):
    """Reorders the OneOfProtoMessage so that the last element can be extracted
    out of the protobuf oneof's field in order to always have a correct
    path to be generated. This requires having at least one terminal path in
    the grammar.
    """
    _terminal_messages = set()
    _being_visited = set()

    def recursive_terminal_marker(name: str):
      # Unregistered names (e.g. built-in types) count as terminal.
      if name in _terminal_messages or name not in self.handlers:
        return True
      # A message reached again while being visited is a cycle.
      if name in _being_visited:
        return False
      _being_visited.add(name)
      msg = self.handlers[name].msg
      func = self.handlers[name].func
      if len(msg.fields) == 0:
        _terminal_messages.add(name)
        _being_visited.remove(name)
        return True
      if msg.is_one_of():
        f = next(
            (f for f in msg.fields if recursive_terminal_marker(f.type.name)),
            None)
        if not f:
          #FIXME: for testing purpose only, we're not hard-failing on this.
          _being_visited.remove(name)
          return False
        # Move the terminal field, and its case, to the last position.
        msg.fields.remove(f)
        msg.fields.append(f)
        func.cases[f.name] = func.cases.pop(f.name)
        _terminal_messages.add(name)
        _being_visited.remove(name)
        return True
      res = all(recursive_terminal_marker(f.type.name) for f in msg.fields)
      #FIXME: for testing purpose only, we're not hard-failing on this.
      _being_visited.remove(name)
      return res

    for name in self.handlers:
      recursive_terminal_marker(name)

  def _merge_oneofs(self) -> bool:
    """Inlines single-field messages into the oneofs referencing them.

    A oneof field whose type is a registered, non-oneof, non-creator message
    with exactly one field is replaced by that single field, and the oneof
    case takes over a copy of the message handler's expressions.

    Returns:
      whether a change was made.
    """
    has_made_changes = False
    for name in list(self.handlers.keys()):
      msg = self.handlers[name].msg
      func = self.handlers[name].func
      if not msg.is_one_of():
        continue
      for field in msg.fields:
        if field.type.name not in self.handlers:
          continue
        field_msg = self.handlers[field.type.name].msg
        field_func = self.handlers[field.type.name].func
        if (field_msg.is_one_of() or len(field_msg.fields) != 1
            or not field_func.is_message_handler()
            or field_func.creates_new()):
          continue
        func.cases.pop(field.name)
        field.name = field_msg.fields[0].name
        field.type = field_msg.fields[0].type
        # Keep case names unique within this oneof.
        while field.name in func.cases:
          field.name += '_1'
        func.cases[field.name] = copy.deepcopy(field_func.exprs)
        self.backrefs[field_msg.name].remove(name)
        self.backrefs[field.type.name].append(name)
        has_made_changes = True
    return has_made_changes

  def _merge_unary_oneofs(self) -> bool:
    """Transforms OneOfProtoMessage messages containing only one field into a
    ProtoMessage containing the fields of the contained message.

    E.g.:
      message B {
        int field1 = 1;
        Whatever field2 = 2;
      }
      message A {
        oneof field {
          B b = 1;
        }
      }
    Into:
      message A {
        int field1 = 1;
        Whatever field2 = 2;
      }

    Returns:
      whether a change was made.
    """
    has_made_changes = False
    for name in list(self.handlers.keys()):
      msg = self.handlers[name].msg
      # Only oneofs with exactly one field qualify; an empty oneof has no
      # child to merge.
      if not msg.is_one_of() or len(msg.fields) != 1:
        continue
      child_name = msg.fields[0].type.name
      # The message is a unary oneof. Let's make sure it's only child doesn't
      # have backrefs.
      if self._count_backref(child_name) > 1:
        continue
      # The only backref should really only be us. If not we screwed up
      # somewhere else.
      assert name in self.backrefs[child_name]
      # The child type may not be a registered message (e.g. a built-in type
      # inlined by _merge_oneofs); there is nothing to merge then.
      if child_name not in self.handlers:
        continue
      field_msg: ProtoMessage = self.handlers[child_name].msg
      if field_msg.is_one_of():
        continue
      field_func = self.handlers[child_name].func
      self._remove(child_name)
      msg = ProtoMessage(name=msg.name, fields=field_msg.fields)
      func = CppProtoMessageFunctionHandler(name=msg.name,
                                            exprs=field_func.exprs,
                                            creator=field_func.creator)
      self.handlers[name] = DomatoBuilder.Entry(msg, func)
      self._update(name)
      has_made_changes = True
    return has_made_changes

  def _merge_strings(self) -> bool:
    """Merges consecutive CppStringExpr, e.g.
      [CppStringExpr("foo"), CppStringExpr("bar")]
    Into:
      [CppStringExpr("foobar")]

    Returns:
      whether a change was made.
    """
    has_made_changes = False
    for name in self.handlers:
      func: CppFunctionHandler = self.handlers[name].func
      if not func.is_message_handler() or len(func.exprs) <= 1:
        continue
      exprs = []
      prev = func.exprs[0]
      for cur in func.exprs[1:]:
        if isinstance(prev, CppStringExpr) and isinstance(cur, CppStringExpr):
          cur = CppStringExpr(prev.content + cur.content)
          has_made_changes = True
        else:
          exprs.append(prev)
        prev = cur
      exprs.append(prev)
      func.exprs = exprs
    return has_made_changes

  def _is_root_node(self, name: str):
    """Returns whether |name| is (one of) the root message(s)."""
    # If there is no existing root, we set it to `lines`, since this will
    # be picked as the default root.
    if 'line' not in self.root:
      return self.root == name
    return re.match('^line(s)?(_[0-9]*)?$', name) is not None

  def _remove_unlinked_nodes(self) -> bool:
    """Removes proto messages that are neither part of the root definition nor
    referenced by any other messages. This can happen during other optimization
    functions.

    Returns:
      whether a change was made.
    """
    to_remove = set()
    for name in self.handlers:
      if name not in self.backrefs or len(self.backrefs[name]) == 0:
        if not self._is_root_node(name):
          to_remove.add(name)
    local_root = ('line'
                  if self.should_generate_one_line_handler() else self.root)
    seen = set()

    def visit_msg(msg: ProtoMessage):
      if msg.name in seen:
        return
      seen.add(msg.name)
      for field in msg.fields:
        if field.type.name in self.handlers:
          visit_msg(self.handlers[field.type.name].msg)

    # Anything that cannot be reached from the root is removed as well.
    visit_msg(self.handlers[local_root].msg)
    not_seen = set(self.handlers.keys()) - seen
    to_remove.update(x for x in not_seen if not self._is_root_node(x))
    for t in to_remove:
      self._remove(t)
    return len(to_remove) > 0


def _render_internal(template: jinja2.Template,
                     context: typing.Dict[str, typing.Any], out_f: str):
  """Renders |template| with |context| into the file |out_f|."""
  with action_helpers.atomic_output(out_f, mode='w') as f:
    f.write(template.render(context))


def _render_proto_internal(
    template: jinja2.Template, out_f: str,
    proto_messages: typing.List[typing.Union[ProtoMessage, OneOfProtoMessage]],
    should_generate_repeated_lines: bool, proto_ns: str,
    imports: typing.List[str]):
  """Renders a .proto file, passing oneof and regular messages separately."""
  _render_internal(template, {
      'messages': [m for m in proto_messages if not m.is_one_of()],
      'oneofmessages': [m for m in proto_messages if m.is_one_of()],
      'generate_repeated_lines': should_generate_repeated_lines,
      'proto_ns': proto_ns,
      'imports': imports,
  },
                   out_f=out_f)


def render_proto(environment: jinja2.Environment, generated_dir: str,
                 out_f: str, name: str, builder: DomatoBuilder):
  """Renders `<out_f>.proto` (roots) and `<out_f>_sub.proto` (non-roots).

  The root file imports the sub file through |generated_dir|.
  """
  template = environment.get_template('domatolpm.proto.tmpl')
  roots, non_roots = builder.get_protos()
  ns = f'{BASE_PROTO_NS}.{name}'
  sub_proto_filename = pathlib.PurePosixPath(f'{out_f}_sub.proto').name
  import_path = pathlib.PurePosixPath(generated_dir).joinpath(
      sub_proto_filename)
  _render_proto_internal(template, f'{out_f}.proto', roots,
                         builder.should_generate_repeated_lines(), ns,
                         [str(import_path)])
  _render_proto_internal(template, f'{out_f}_sub.proto', non_roots, False, ns,
                         [])


def render_cpp(environment: jinja2.Environment, out_f: str, name: str,
               builder: DomatoBuilder):
  """Renders `<out_f>.cc` and `<out_f>.h` from the builder's handlers."""
  functions = builder.all_cpp_functions()
  funcs = [f for f in functions if f.is_message_handler()]
  oneofs = [f for f in functions if f.is_oneof_handler()]
  stfunctions = [f for f in functions if f.is_string_table_handler()]
  _, root_func = builder.get_roots()

  rendering_context = {
      'basename': os.path.basename(out_f),
      'functions': funcs,
      'oneoffunctions': oneofs,
      'stfunctions': stfunctions,
      'root': root_func,
      'generate_repeated_lines': builder.should_generate_repeated_lines(),
      'generate_one_line_handler': builder.should_generate_one_line_handler(),
      'line_prefix': builder.get_line_prefix(),
      'line_suffix': builder.get_line_suffix(),
      'proto_ns': to_cpp_ns(f'{BASE_PROTO_NS}.{name}'),
      'cpp_ns': f'domatolpm::{name}',
  }
  template = environment.get_template('domatolpm.cc.tmpl')
  _render_internal(template, rendering_context, f'{out_f}.cc')
  template = environment.get_template('domatolpm.h.tmpl')
  _render_internal(template, rendering_context, f'{out_f}.h')


def main():
  """Parses the command line and generates the C++ and proto files."""
  parser = argparse.ArgumentParser(
      description=
      'Generate the necessary files for DomatoLPM to function properly.')
  parser.add_argument('-p',
                      '--path',
                      required=True,
                      help='The path to a Domato grammar file.')
  parser.add_argument('-n',
                      '--name',
                      required=True,
                      help='The name of this grammar.')
  parser.add_argument(
      '-f',
      '--file-format',
      required=True,
      help='The path prefix to which the files should be generated.')
  parser.add_argument('-d',
                      '--generated-dir',
                      required=True,
                      help='The path to the target gen directory.')

  args = parser.parse_args()
  g = grammar.Grammar()
  g.parse_from_file(filename=args.path)

  template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                              'templates')
  environment = jinja2.Environment(loader=jinja2.FileSystemLoader(template_dir))
  builder = DomatoBuilder(g)
  builder.parse_grammar()
  builder.simplify()
  render_cpp(environment, args.file_format, args.name, builder)
  render_proto(environment, args.generated_dir, args.file_format, args.name,
               builder)


if __name__ == '__main__':
  main()