#!/usr/bin/env python3
#
# Copyright 2024 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import argparse
import collections
import copy
import dataclasses
import os
import pathlib
import re
import sys
import typing


def _GetDirAbove(dirname: str) -> typing.Optional[str]:
  """Returns the parent of the nearest ancestor directory named |dirname|.

  The search walks up from this file's absolute path, one component at a
  time.

  Args:
    dirname: the name of the ancestor directory to look for.

  Returns:
    the directory containing |dirname|, or None if no component of this
    file's path equals |dirname|.
  """
  path = os.path.abspath(__file__)
  while True:
    path, tail = os.path.split(path)
    if not tail:
      # Reached the filesystem root without finding |dirname|.
      return None
    if tail == dirname:
      return path


SOURCE_DIR = _GetDirAbove('testing')
if SOURCE_DIR is None:
  # Fail early with an explicit message rather than an obscure TypeError from
  # os.path.join(None, ...) below.
  raise RuntimeError(
      f'Could not locate a "testing" directory above {__file__}.')

sys.path.insert(1, os.path.join(SOURCE_DIR, 'third_party'))
sys.path.insert(1, os.path.join(SOURCE_DIR, 'third_party/domato/src'))
sys.path.append(os.path.join(SOURCE_DIR, 'build'))

import action_helpers
import jinja2
import grammar

# TODO(crbug.com/361369290): Remove this disable once DomatoLPM development is
# finished and upstream changes can be made to expose the relevant protected
# fields.
# pylint: disable=protected-access

# Names that receive a '_proto' suffix when used as proto field or type names,
# presumably to avoid clashing with C++ keywords/types in the generated code.
_RESERVED_NAMES = frozenset(
    ('short', 'class', 'bool', 'boolean', 'long', 'void'))


def to_snake_case(name):
  """Converts a CamelCase name to snake_case.

  Runs of capitals are kept together, e.g. 'HTMLElement' -> 'html_element'.
  """
  name = re.sub(r'([A-Z]{2,})([A-Z][a-z])', r'\1_\2', name)
  # The default count (0) already replaces every occurrence; passing the count
  # positionally is deprecated since Python 3.13.
  return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', name).lower()


# Maps a domato integer type to the C++ type it denotes. Used to spell the
# default lower bound (std::numeric_limits<T>::min()) when a tag only sets
# `max`.
DOMATO_INT_TYPE_TO_CPP_INT_TYPE = {
    'int': 'int',
    'int32': 'int32_t',
    'uint32': 'uint32_t',
    'int8': 'int8_t',
    'uint8': 'uint8_t',
    'int16': 'int16_t',
    'uint16': 'uint16_t',
    'int64': 'int64_t',
    'uint64': 'uint64_t',
}

# Maps a domato built-in type to the proto type used to carry it. Protobuf has
# no 8-bit or 16-bit scalar types, so those are carried in the 32-bit ones.
DOMATO_TO_PROTO_BUILT_IN = {
    'int': 'int32',
    'int32': 'int32',
    'uint32': 'uint32',
    'int8': 'int32',
    'uint8': 'uint32',
    'int16': 'int32',
    'uint16': 'uint32',
    'int64': 'int64',
    'uint64': 'uint64',
    'float': 'float',
    'double': 'double',
    'char': 'int32',
    'string': 'string',
    'htmlsafestring': 'string',
    'hex': 'int32',
    'lines': 'repeated lines',
}

# Maps a domato built-in type to the C++ handler that renders its proto value.
# For integers the template arguments appear to be
# <C++ type of the proto field, C++ type of the domato value>.
DOMATO_TO_CPP_HANDLERS = {
    'int': 'handle_int_conversion<int32_t, int>',
    'int32': 'handle_int_conversion<int32_t, int32_t>',
    'uint32': 'handle_int_conversion<uint32_t, uint32_t>',
    'int8': 'handle_int_conversion<int32_t, int8_t>',
    'uint8': 'handle_int_conversion<uint32_t, uint8_t>',
    'int16': 'handle_int_conversion<int32_t, int16_t>',
    'uint16': 'handle_int_conversion<uint32_t, uint16_t>',
    'int64': 'handle_int_conversion<int64_t, int64_t>',
    'uint64': 'handle_int_conversion<uint64_t, uint64_t>',
    'float': 'handle_float',
    'double': 'handle_double',
    'char': 'handle_char',
    'string': 'handle_string',
    'htmlsafestring': 'handle_string',
    'hex': 'handle_hex',
}

# Escapes applied to grammar text before it is emitted as a C++ string literal.
_C_STR_TRANS = str.maketrans({
    '\n': '\\n',
    '\r': '\\r',
    '\t': '\\t',
    '\"': '\\\"',
    '\\': '\\\\'
})

BASE_PROTO_NS = 'domatolpm.generated'


def to_cpp_ns(proto_ns: str) -> str:
  """Converts a dotted proto package name into a C++ namespace."""
  return proto_ns.replace('.', '::')


CPP_HANDLER_PREFIX = 'handle_'


def to_proto_field_name(name: str) -> str:
  """Converts a creator or rule name to a proto field name. This tries to
  respect the protobuf naming convention that field names should be snake case.

  Args:
    name: the name of the creator or the rule.

  Returns:
    the proto field name to use.
  """
  res = to_snake_case(name.replace('-', '_'))
  if res in _RESERVED_NAMES:
    res += '_proto'
  return res


def to_proto_type(creator_name: str) -> str:
  """Converts a creator name to a proto type. This is deliberately very simple
  so that we avoid naming conflicts.

  Args:
    creator_name: the name of the creator.

  Returns:
    the name of the proto type.
  """
  res = creator_name.replace('-', '_')
  if res in _RESERVED_NAMES:
    res += '_proto'
  return res


def c_escape(v: str) -> str:
  """Escapes |v| so that it can be placed inside a C++ double-quoted literal."""
  return v.translate(_C_STR_TRANS)


@dataclasses.dataclass
class ProtoType:
  """Represents a Proto type."""
  name: str

  def is_one_of(self) -> bool:
    return False


@dataclasses.dataclass
class ProtoField:
  """Represents a proto message field."""
  type: ProtoType
  name: str
  proto_id: int


@dataclasses.dataclass
class ProtoMessage(ProtoType):
  """Represents a Proto message."""
  fields: typing.List[ProtoField]


@dataclasses.dataclass
class OneOfProtoMessage(ProtoMessage):
  """Represents a Proto message with a oneof field."""
  oneofname: str

  def is_one_of(self) -> bool:
    return True


class CppExpression:
  """Base class for the C++ expressions emitted into generated handlers."""

  def repr(self):
    """Returns the C++ source text of this expression."""
    raise NotImplementedError('Not implemented.')


@dataclasses.dataclass
class CppTxtExpression(CppExpression):
  """Represents a Raw text expression."""
  content: str

  def repr(self):
    return self.content


@dataclasses.dataclass
class CppCallExpr(CppExpression):
  """Represents a CallExpr."""
  fct_name: str
  args: typing.List[CppExpression]
  ns: str = ''

  def repr(self):
    arg_s = ', '.join(a.repr() for a in self.args)
    return f'{self.ns}{self.fct_name}({arg_s})'


class CppHandlerCallExpr(CppCallExpr):
  """Represents `handler(ctx, arg.<field_name>()[, extra_args...])`."""

  def __init__(self,
               handler: str,
               field_name: str,
               extra_args: typing.Optional[typing.List[CppExpression]] = None):
    args = [CppTxtExpression('ctx'), CppTxtExpression(f'arg.{field_name}()')]
    if extra_args:
      args += extra_args
    super().__init__(fct_name=handler, args=args)
    self.handler = handler
    self.field_name = field_name
    self.extra_args = extra_args


@dataclasses.dataclass
class CppStringExpr(CppExpression):
  """Represents a C++ literal string.
  """
  content: str

  def repr(self):
    return f'\"{c_escape(self.content)}\"'


@dataclasses.dataclass
class CppFunctionHandler:
  """Represents a C++ function.
  """
  name: str
  exprs: typing.List[CppExpression]

  def is_oneof_handler(self) -> bool:
    return False

  def is_string_table_handler(self) -> bool:
    return False

  def is_message_handler(self) -> bool:
    return False


class CppStringTableHandler(CppFunctionHandler):
  """Represents a C++ function that implements a string table and returns one
  of the represented strings.
  """

  def __init__(self, name: str, var_name: str,
               strings: typing.List[CppStringExpr]):
    super().__init__(name=f'{CPP_HANDLER_PREFIX}{name}', exprs=[])
    self.proto_type = f'{name}& arg'
    self.strings = strings
    self.var_name = var_name

  def is_string_table_handler(self) -> bool:
    return True


class CppProtoMessageFunctionHandler(CppFunctionHandler):
  """Represents a C++ function that handles a ProtoMessage.
  """

  def __init__(self,
               name: str,
               exprs: typing.List[CppExpression],
               creator: typing.Optional[typing.Dict[str, str]] = None):
    super().__init__(name=f'{CPP_HANDLER_PREFIX}{name}', exprs=exprs)
    self.proto_type = f'{name}& arg'
    # Non-None for code rules that declare a new variable (see _parse_rule).
    self.creator = creator

  def creates_new(self):
    return self.creator is not None

  def is_message_handler(self) -> bool:
    return True


class CppOneOfMessageFunctionHandler(CppFunctionHandler):
  """Represents a C++ function that handles a OneOfProtoMessage.
  """

  def __init__(self, name: str, switch_name: str,
               cases: typing.Dict[str, typing.List[CppExpression]]):
    super().__init__(name=f'{CPP_HANDLER_PREFIX}{name}', exprs=[])
    self.proto_type = f'{name}& arg'
    self.switch_name = switch_name
    # Maps a oneof field name to the expressions rendered for that case.
    self.cases = cases

  def all_except_last(self):
    """Returns every case but the last one, in insertion order."""
    a = list(self.cases.keys())[:-1]
    return {e: self.cases[e] for e in a}

  def last(self):
    """Returns the expressions of the last inserted case."""
    a = list(self.cases.keys())[-1]
    return self.cases[a]

  def is_oneof_handler(self) -> bool:
    return True


class DomatoBuilder:
  """DomatoBuilder is the class that takes a Domato grammar, and modelize it
  into a protobuf representation and its corresponding C++ parsing code.
  """

  @dataclasses.dataclass
  class Entry:
    msg: ProtoMessage
    func: CppFunctionHandler

  def __init__(self, g: grammar.Grammar):
    # Message name -> (proto message, C++ handler).
    self.handlers: typing.Dict[str, DomatoBuilder.Entry] = {}
    # Type name -> names of the messages having a field of that type (one
    # entry per field, so a name may appear several times).
    self.backrefs: typing.Dict[str, typing.List[str]] = {}
    self.grammar = g
    if self.grammar._root and self.grammar._root != 'root':
      self.root = self.grammar._root
    else:
      self.root = 'lines'
    if self.grammar._root and self.grammar._root == 'root':
      rules = self.grammar._creators[self.grammar._root]
      # multiple roots doesn't make sense, so we only consider the last defined
      # one.
      rule = rules[-1]
      for part in rule['parts']:
        if part['type'] == 'tag' and part[
            'tagname'] == 'lines' and 'count' in part:
          self.root = f'lines_{part["count"]}'
          break
    self._built_in_types_parser = {
        'int': self._int_handler,
        'int32': self._int_handler,
        'uint32': self._int_handler,
        'int8': self._int_handler,
        'uint8': self._int_handler,
        'int16': self._int_handler,
        'uint16': self._int_handler,
        'int64': self._int_handler,
        'uint64': self._int_handler,
        'float': self._default_handler,
        'double': self._default_handler,
        'char': self._default_handler,
        'string': self._default_handler,
        'htmlsafestring': self._default_handler,
        'hex': self._default_handler,
        'lines': self._lines_handler,
    }

  def parse_grammar(self):
    """Creates one oneof message per creator, with one case per rule."""
    for creator, rules in self.grammar._creators.items():
      field_name = to_proto_field_name(creator)
      type_name = to_proto_type(creator)
      messages = self._parse_rule(creator, rules)
      proto_fields: typing.List[ProtoField] = []
      for proto_id, rule_msg in enumerate(messages, start=1):
        proto_fields.append(
            ProtoField(type=ProtoType(name=rule_msg.name),
                       name=f'{field_name}_{proto_id}',
                       proto_id=proto_id))
      msg = OneOfProtoMessage(name=type_name,
                              oneofname='oneoffield',
                              fields=proto_fields)
      cases = {
          f.name: [
              CppHandlerCallExpr(handler=f'{CPP_HANDLER_PREFIX}{f.type.name}',
                                 field_name=f.name)
          ]
          for f in proto_fields
      }
      func = CppOneOfMessageFunctionHandler(name=type_name,
                                            switch_name='oneoffield',
                                            cases=cases)
      self._add(msg, func)

  def all_proto_messages(self):
    return [v.msg for v in self.handlers.values()]

  def all_cpp_functions(self):
    return [v.func for v in self.handlers.values()]

  def get_line_prefix(self) -> str:
    """Returns the line guard text before `<line>`, or '' without a guard."""
    if not self.grammar._line_guard:
      return ''
    return self.grammar._line_guard.split('<line>')[0]

  def get_line_suffix(self) -> str:
    """Returns the line guard text after `<line>`, or '' without a guard."""
    if not self.grammar._line_guard:
      return ''
    return self.grammar._line_guard.split('<line>')[1]

  def should_generate_repeated_lines(self):
    return self.root == 'lines'

  def should_generate_one_line_handler(self):
    return self.root.startswith('lines')

  def maybe_add_lines_handler(self, number: int) -> bool:
    """Registers a `lines_<number>` message made of |number| `line` fields.

    NOTE(review): the entry is stored directly in self.handlers rather than
    through _add, so no backrefs are recorded for its `line` fields.

    Returns:
      False if the message already existed, True otherwise.
    """
    name = f'lines_{number}'
    if name in self.handlers:
      return False
    fields = []
    exprs = []
    for i in range(1, number + 1):
      fields.append(ProtoField(ProtoType('line'), f'line_{i}', i))
      exprs.append(CppHandlerCallExpr('handle_one_line', f'line_{i}'))
    msg = ProtoMessage(name, fields=fields)
    handler = CppProtoMessageFunctionHandler(name, exprs=exprs)
    self.handlers[name] = DomatoBuilder.Entry(msg, handler)
    return True

  def get_roots(self) -> typing.Tuple[ProtoMessage, CppFunctionHandler]:
    """Returns the top-level `fuzzcase` message and its handler."""
    root = self.root
    root_handler = f'{CPP_HANDLER_PREFIX}{root}'
    fuzz_case = ProtoMessage(
        name='fuzzcase',
        fields=[ProtoField(type=ProtoType(name=root), name='root', proto_id=1)])
    fuzz_fct = CppProtoMessageFunctionHandler(
        name='fuzzcase',
        exprs=[CppHandlerCallExpr(handler=root_handler, field_name='root')])
    return fuzz_case, fuzz_fct

  def get_protos(
      self
  ) -> typing.Tuple[typing.List[ProtoMessage], typing.List[ProtoMessage]]:
    """Splits messages between the root proto file and the sub proto file.

    render_proto makes the root file import the sub file, not the other way
    around.

    Returns:
      a (roots, non_roots) tuple of message lists.
    """
    if self.should_generate_one_line_handler():
      # We're handling a code grammar. Use the same root predicate as the rest
      # of the builder: a plain startswith('line') would also catch unrelated
      # creators such as `linear-gradient`, which the sub proto could then not
      # reference.
      roots = [v.msg for k, v in self.handlers.items() if self._is_root_node(k)]
      roots.append(self.get_roots()[0])
      non_roots = [
          v.msg for k, v in self.handlers.items() if not self._is_root_node(k)
      ]
      return roots, non_roots
    return [self.get_roots()[0]], self.all_proto_messages()

  def simplify(self):
    """Simplifies the proto and functions.

    Runs every merging pass until none of them reports a change, then
    normalizes oneof ordering and field names.
    """
    should_continue = True
    while should_continue:
      should_continue = False
      should_continue |= self._merge_unary_oneofs()
      should_continue |= self._merge_strings()
      should_continue |= self._merge_multistrings_oneofs()
      should_continue |= self._remove_unlinked_nodes()
      should_continue |= self._merge_proto_messages()
      should_continue |= self._merge_oneofs()
    self._oneofs_reorderer()
    self._oneof_message_renamer()
    self._message_renamer()

  def _add(self, message: ProtoMessage, handler: CppFunctionHandler):
    """Registers |message|/|handler| and records backrefs for its fields."""
    self.handlers[message.name] = DomatoBuilder.Entry(message, handler)
    for field in message.fields:
      if not field.type.name in self.backrefs:
        self.backrefs[field.type.name] = []
      self.backrefs[field.type.name].append(message.name)

  def _int_handler(self, part,
                   field_name: str) -> typing.Tuple[str, CppHandlerCallExpr]:
    """Handles integer tags, forwarding optional `min`/`max` bounds."""
    proto_type = DOMATO_TO_PROTO_BUILT_IN[part['tagname']]
    handler = DOMATO_TO_CPP_HANDLERS[part['tagname']]
    extra_args = []
    if 'min' in part:
      extra_args.append(CppTxtExpression(part['min']))
    if 'max' in part:
      if not extra_args:
        # A max without a min: default the min to the type's lowest value.
        cpp_type = DOMATO_INT_TYPE_TO_CPP_INT_TYPE[part['tagname']]
        extra_args.append(
            CppTxtExpression(f'std::numeric_limits<{cpp_type}>::min()'))
      extra_args.append(CppTxtExpression(part['max']))
    contents = CppHandlerCallExpr(handler=handler,
                                  field_name=field_name,
                                  extra_args=extra_args)
    return proto_type, contents

  def _lines_handler(self, part,
                     field_name: str) -> typing.Tuple[str, CppHandlerCallExpr]:
    """Handles `<lines>` tags, registering `lines_<count>` when needed."""
    handler_name = 'lines'
    if 'count' in part:
      count = part['count']
      handler_name = f'{handler_name}_{count}'
      self.maybe_add_lines_handler(int(part['count']))
    proto_type = handler_name
    contents = CppHandlerCallExpr(handler=f'{CPP_HANDLER_PREFIX}{handler_name}',
                                  field_name=field_name)
    return proto_type, contents

  def _default_handler(
      self, part, field_name: str) -> typing.Tuple[str, CppHandlerCallExpr]:
    """Handles the non-integer built-in tags through the static tables."""
    proto_type = DOMATO_TO_PROTO_BUILT_IN[part['tagname']]
    handler = DOMATO_TO_CPP_HANDLERS[part['tagname']]
    contents = CppHandlerCallExpr(handler=handler, field_name=field_name)
    return proto_type, contents

  def _parse_rule(self, creator_name, rules):
    """Creates one message (and handler) per rule of |creator_name|.

    Returns:
      the list of created rule messages, in rule order.
    """
    messages = []
    for rule_id, rule in enumerate(rules, start=1):
      rule_msg_field_name = f'{to_proto_field_name(creator_name)}_{rule_id}'
      proto_fields = []
      cpp_contents = []
      ret_vars = 0
      for part_id, part in enumerate(rule['parts'], start=1):
        field_name = f'{rule_msg_field_name}_{part_id}'
        proto_type = None
        if rule['type'] == 'code' and 'new' in part:
          # The variable declared by this rule: kept as the first field so the
          # renamers can skip it.
          proto_fields.insert(
              0,
              ProtoField(type=ProtoType('optional int32'),
                         name='old',
                         proto_id=part_id))
          ret_vars += 1
          continue
        if part['type'] == 'text':
          contents = CppStringExpr(part['text'])
        elif part['tagname'] == 'import':
          # The current domato project is currently not handling that either in
          # its built-in rules, and I do not plan on using the feature with
          # newly written rules, as I think this directive has a lot of
          # constraints with not much added value.
          continue
        elif part['tagname'] == 'call':
          raise Exception(
              'DomatoLPM does not implement <call> and <import> tags.')
        elif part['tagname'] in self.grammar._constant_types.keys():
          contents = CppStringExpr(
              self.grammar._constant_types[part['tagname']])
        elif part['tagname'] in self._built_in_types_parser:
          handler = self._built_in_types_parser[part['tagname']]
          proto_type, contents = handler(part, field_name)
        elif part['type'] == 'tag':
          proto_type = to_proto_type(part['tagname'])
          contents = CppHandlerCallExpr(
              handler=f'{CPP_HANDLER_PREFIX}{proto_type}',
              field_name=field_name)
        if proto_type:
          proto_fields.append(
              ProtoField(type=ProtoType(name=proto_type),
                         name=field_name,
                         proto_id=part_id))
        cpp_contents.append(contents)

      if ret_vars > 1:
        raise Exception('Not implemented.')

      creator = None
      if rule['type'] == 'code' and ret_vars > 0:
        creator = {'var_type': creator_name, 'var_prefix': 'var'}
      proto_type = to_proto_type(creator_name)
      rule_msg = ProtoMessage(name=f'{proto_type}_{rule_id}',
                              fields=proto_fields)
      rule_func = CppProtoMessageFunctionHandler(name=f'{proto_type}_{rule_id}',
                                                 exprs=cpp_contents,
                                                 creator=creator)

      self._add(rule_msg, rule_func)
      messages.append(rule_msg)
    return messages

  def _remove(self, name: str):
    """Unregisters message |name| and every backref it owns or receives."""
    assert name in self.handlers
    for field in self.handlers[name].msg.fields:
      if field.type.name in self.backrefs:
        self.backrefs[field.type.name].remove(name)
    if name in self.backrefs:
      self.backrefs.pop(name)
    self.handlers.pop(name)

  def _update(self, name: str):
    """Records backrefs for the fields of the (replaced) message |name|."""
    assert name in self.handlers
    for field in self.handlers[name].msg.fields:
      if not field.type.name in self.backrefs:
        self.backrefs[field.type.name] = []
      self.backrefs[field.type.name].append(name)

  def _count_backref(self, proto_name: str) -> int:
    """Counts the number of backreferences a given proto message has.

    Args:
      proto_name: the proto message name.

    Returns:
      the number of backreferences.
    """
    return len(self.backrefs[proto_name])

  def _merge_proto_messages(self) -> bool:
    """Merges messages referencing other messages into the same message. This
    allows to tremendously reduce the number of protobuf messages that will be
    generated.

    Returns:
      whether a change was made.
    """
    to_merge = collections.defaultdict(set)
    for name in self.handlers:
      msg = self.handlers[name].msg
      func = self.handlers[name].func
      if msg.is_one_of() or not func.is_message_handler() or func.creates_new(
      ) or self._is_root_node(name):
        continue
      if name not in self.backrefs:
        continue
      for elt in self.backrefs[name]:
        if elt == name or elt not in self.handlers:
          continue
        if self.handlers[elt].msg.is_one_of():
          continue
        to_merge[elt].add(name)

    for parent, childs in to_merge.items():
      msg = self.handlers[parent].msg
      fct = self.handlers[parent].func
      # Sort so the merge order (and thus the generated files) does not depend
      # on string hash randomization.
      for child in sorted(childs):
        new_contents = []
        for expr in fct.exprs:
          if isinstance(expr, CppStringExpr):
            new_contents.append(expr)
            continue
          assert isinstance(expr, CppHandlerCallExpr)
          field: ProtoField = next(
              (f for f in msg.fields if f.type.name == child), None)
          if not field or not expr.field_name == field.name:
            new_contents.append(expr)
            continue
          self.backrefs[field.type.name].remove(msg.name)
          idx = msg.fields.index(field)
          field_msg = self.handlers[child].msg
          field_fct = self.handlers[child].func

          # The following deepcopy is required because we might change the
          # child's messages fields at some point, and we don't want those
          # changes to affect this current's message fields.
          fields_copy = copy.deepcopy(field_msg.fields)
          msg.fields = msg.fields[:idx] + fields_copy + msg.fields[idx + 1:]
          new_contents += copy.deepcopy(field_fct.exprs)
          for f in field_msg.fields:
            self.backrefs[f.type.name].append(msg.name)
        fct.exprs = new_contents
    return len(to_merge) > 0

  def _message_renamer(self):
    """Renames ProtoMessage fields that might have been merged. This ensures
    proto field naming remains consistent with the current rule being
    generated.
    """
    for entry in self.handlers.values():
      if entry.msg.is_one_of() or entry.func.is_string_table_handler():
        continue
      for proto_id, field in enumerate(entry.msg.fields, start=1):
        field.proto_id = proto_id
        if entry.func.creates_new() and field.name == 'old':
          continue
        field.name = to_proto_field_name(f'{entry.msg.name}_{proto_id}')
      # Handler calls map positionally onto fields; the 'old' field (first,
      # see _parse_rule) has no handler call.
      index = 2 if entry.func.creates_new() else 1
      new_contents = []
      for expr in entry.func.exprs:
        if not isinstance(expr, CppHandlerCallExpr):
          new_contents.append(expr)
          continue
        new_contents.append(
            CppHandlerCallExpr(expr.handler,
                               to_proto_field_name(f'{entry.msg.name}_{index}'),
                               expr.extra_args))
        index += 1
      entry.func.exprs = new_contents

  def _oneof_message_renamer(self):
    """Renames OneOfProtoMessage fields that might have been merged. This
    ensures proto field naming remains consistent with the current rule being
    generated.
    """
    for entry in self.handlers.values():
      if not entry.msg.is_one_of():
        continue
      cases = {}
      for proto_id, field in enumerate(entry.msg.fields, start=1):
        field.proto_id = proto_id
        exprs = entry.func.cases.pop(field.name)
        field.name = to_proto_field_name(f'{entry.msg.name}_{proto_id}')
        new_contents = []
        for expr in exprs:
          if not isinstance(expr, CppHandlerCallExpr):
            new_contents.append(expr)
            continue
          new_contents.append(
              CppHandlerCallExpr(expr.handler, field.name, expr.extra_args))
        cases[field.name] = new_contents
      entry.func.cases = cases

  def _merge_multistrings_oneofs(self) -> bool:
    """Merges multiple strings into a string table function.

    A oneof whose every case is a field-less message producing a single
    expression becomes a message with one `uint32 val` field, handled by a
    CppStringTableHandler.

    Returns:
      whether a change was made.
    """

    def is_single_string_message(type_name: str) -> bool:
      # A registered, non-oneof message without fields and with exactly one
      # expression (necessarily a string literal).
      if type_name not in self.handlers:
        return False
      entry = self.handlers[type_name]
      return (len(entry.msg.fields) == 0 and not entry.msg.is_one_of()
              and len(entry.func.exprs) == 1)

    has_made_changes = False
    for name in list(self.handlers.keys()):
      msg = self.handlers[name].msg

      if not msg.is_one_of():
        continue

      if not all(is_single_string_message(f.type.name) for f in msg.fields):
        continue

      fields = [ProtoField(type=ProtoType('uint32'), name='val', proto_id=1)]
      new_msg = ProtoMessage(name=msg.name, fields=fields)
      strings = []
      for field in msg.fields:
        self.backrefs[field.type.name].remove(name)
        for expr in self.handlers[field.type.name].func.exprs:
          assert isinstance(expr, CppStringExpr)
          strings += [expr]
      new_func = CppStringTableHandler(name=msg.name,
                                       var_name='val',
                                       strings=strings)
      self.handlers[name] = DomatoBuilder.Entry(new_msg, new_func)
      self._update(name)
      has_made_changes = True
    return has_made_changes

  def _oneofs_reorderer(self):
    """Reorders the OneOfProtoMessage so that the last element can be extracted
    out of the protobuf oneof's field in order to always have a correct
    path to be generated. This requires having at least one terminal path in
    the grammar.
    """
    _terminal_messages = set()
    _being_visited = set()

    def recursive_terminal_marker(name: str):
      # Returns whether |name| can be generated without infinite recursion.
      if name in _terminal_messages or name not in self.handlers:
        return True
      if name in _being_visited:
        return False
      _being_visited.add(name)
      msg = self.handlers[name].msg
      func = self.handlers[name].func
      if len(msg.fields) == 0:
        _terminal_messages.add(name)
        _being_visited.remove(name)
        return True
      if msg.is_one_of():
        f = next(
            (f for f in msg.fields if recursive_terminal_marker(f.type.name)),
            None)
        if not f:
          #FIXME: for testing purpose only, we're not hard-failing on this.
          _being_visited.remove(name)
          return False
        # Move the terminal case to the end of both the fields and the cases.
        msg.fields.remove(f)
        msg.fields.append(f)
        func.cases[f.name] = func.cases.pop(f.name)
        _terminal_messages.add(name)
        _being_visited.remove(name)
        return True
      res = all(recursive_terminal_marker(f.type.name) for f in msg.fields)
      #FIXME: for testing purpose only, we're not hard-failing on this.
      _being_visited.remove(name)
      return res

    for name in self.handlers:
      recursive_terminal_marker(name)

  def _merge_oneofs(self) -> bool:
    """Inlines single-field messages directly into the oneofs using them.

    Returns:
      whether a change was made.
    """
    has_made_changes = False
    for name in list(self.handlers.keys()):
      msg = self.handlers[name].msg
      func = self.handlers[name].func
      if not msg.is_one_of():
        continue

      for field in msg.fields:
        if not field.type.name in self.handlers:
          continue
        field_msg = self.handlers[field.type.name].msg
        field_func = self.handlers[field.type.name].func
        if field_msg.is_one_of() or len(
            field_msg.fields) != 1 or not field_func.is_message_handler(
            ) or field_func.creates_new():
          continue
        func.cases.pop(field.name)
        field.name = field_msg.fields[0].name
        field.type = field_msg.fields[0].type
        # Keep case names unique within this oneof.
        while field.name in func.cases:
          field.name += '_1'
        func.cases[field.name] = copy.deepcopy(field_func.exprs)
        self.backrefs[field_msg.name].remove(name)
        self.backrefs[field.type.name].append(name)
        has_made_changes = True
    return has_made_changes

  def _merge_unary_oneofs(self) -> bool:
    """Transforms OneOfProtoMessage messages containing only one field into a
    ProtoMessage containing the fields of the contained message. E.g.:
      message B {
        int field1 = 1;
        Whatever field2 = 2;
      }
      message A {
        oneof field {
          B b = 1;
        }
      }
    Into:
      message A {
        int field1 = 1;
        Whatever field2 = 2;
      }

    Returns:
      whether a change was made.
    """
    has_made_changes = False
    for name in list(self.handlers.keys()):
      msg = self.handlers[name].msg

      if not msg.is_one_of() or len(msg.fields) != 1:
        continue

      child_name = msg.fields[0].type.name
      # After _merge_oneofs the only field may have a built-in type, for which
      # there is no handler to merge.
      if child_name not in self.handlers:
        continue

      # The message is a unary oneof. Let's make sure its only child doesn't
      # have other backrefs.
      if self._count_backref(child_name) > 1:
        continue

      # The only backref should really only be us. If not we screwed up
      # somewhere else.
      assert name in self.backrefs[child_name]
      field_msg: ProtoMessage = self.handlers[child_name].msg
      field_func = self.handlers[child_name].func
      # Oneofs and string tables cannot be flattened into a plain message
      # handler (string tables have no `creator` and keep no `exprs`).
      if field_msg.is_one_of() or not field_func.is_message_handler():
        continue

      self._remove(child_name)
      new_msg = ProtoMessage(name=msg.name, fields=field_msg.fields)
      new_func = CppProtoMessageFunctionHandler(name=new_msg.name,
                                                exprs=field_func.exprs,
                                                creator=field_func.creator)
      self.handlers[name] = DomatoBuilder.Entry(new_msg, new_func)
      self._update(name)
      has_made_changes = True
    return has_made_changes

  def _merge_strings(self) -> bool:
    """Merges following CppString, e.g.
      [ CppString("<first>"), CppString("<second>")]
    Into:
      [ CppString("<first><second>")]

    Returns:
      whether a change was made.
    """
    has_made_changes = False
    for name in self.handlers:
      func: CppFunctionHandler = self.handlers[name].func
      if not func.is_message_handler() or len(func.exprs) <= 1:
        continue

      exprs = []
      prev = func.exprs[0]
      for i in range(1, len(func.exprs)):
        cur = func.exprs[i]
        if isinstance(prev, CppStringExpr) and isinstance(cur, CppStringExpr):
          cur = CppStringExpr(prev.content + cur.content)
          has_made_changes = True
        else:
          exprs.append(prev)
        prev = cur
      exprs.append(prev)
      func.exprs = exprs
    return has_made_changes

  def _is_root_node(self, name: str):
    """Returns whether |name| is an entry point that must never be merged or
    removed."""
    # If there is no existing root, we set it to `lines`, since this will
    # be picked as the default root.
    if 'line' not in self.root:
      return self.root == name
    return re.match(r'^line(s)?(_[0-9]*)?$', name) is not None

  def _remove_unlinked_nodes(self) -> bool:
    """Removes proto messages that are neither part of the root definition nor
    referenced by any other messages. This can happen during other optimization
    functions.

    Returns:
      whether a change was made.
    """
    to_remove = set()
    for name in self.handlers:
      if name not in self.backrefs or len(self.backrefs[name]) == 0:
        if not self._is_root_node(name):
          to_remove.add(name)
    local_root = 'line' if self.should_generate_one_line_handler(
    ) else self.root

    # Iterative DFS from the root: deep grammars could otherwise exceed
    # Python's recursion limit.
    seen = set()
    stack = [self.handlers[local_root].msg]
    while stack:
      msg = stack.pop()
      if msg.name in seen:
        continue
      seen.add(msg.name)
      for field in msg.fields:
        if field.type.name in self.handlers:
          stack.append(self.handlers[field.type.name].msg)

    not_seen = set(self.handlers.keys()) - seen
    to_remove.update(set(filter(lambda x: not self._is_root_node(x), not_seen)))
    for t in to_remove:
      self._remove(t)
    return len(to_remove) > 0


def _render_internal(template: jinja2.Template,
                     context: typing.Dict[str, typing.Any], out_f: str):
  """Renders |template| with |context| and atomically writes it to |out_f|."""
  with action_helpers.atomic_output(out_f, mode='w') as f:
    f.write(template.render(context))


def _render_proto_internal(
    template: jinja2.Template, out_f: str,
    proto_messages: typing.List[typing.Union[ProtoMessage, OneOfProtoMessage]],
    should_generate_repeated_lines: bool, proto_ns: str,
    imports: typing.List[str]):
  """Renders one .proto file, splitting plain and oneof messages."""
  _render_internal(template, {
      'messages': [m for m in proto_messages if not m.is_one_of()],
      'oneofmessages': [m for m in proto_messages if m.is_one_of()],
      'generate_repeated_lines': should_generate_repeated_lines,
      'proto_ns': proto_ns,
      'imports': imports,
  },
                   out_f=out_f)


def render_proto(environment: jinja2.Environment, generated_dir: str,
                 out_f: str, name: str, builder: DomatoBuilder):
  """Writes `<out_f>.proto` (roots) and `<out_f>_sub.proto` (the rest).

  The root file imports the sub file; the sub file imports nothing.
  """
  template = environment.get_template('domatolpm.proto.tmpl')
  roots, non_roots = builder.get_protos()
  ns = f'{BASE_PROTO_NS}.{name}'
  sub_proto_filename = pathlib.PurePosixPath(f'{out_f}_sub.proto').name
  import_path = pathlib.PurePosixPath(generated_dir).joinpath(
      sub_proto_filename)
  _render_proto_internal(template, f'{out_f}.proto', roots,
                         builder.should_generate_repeated_lines(), ns,
                         [str(import_path)])
  _render_proto_internal(template, f'{out_f}_sub.proto', non_roots, False, ns,
                         [])


def render_cpp(environment: jinja2.Environment, out_f: str, name: str,
               builder: DomatoBuilder):
  """Writes `<out_f>.cc` and `<out_f>.h` from the builder's C++ handlers."""
  functions = builder.all_cpp_functions()
  funcs = [f for f in functions if f.is_message_handler()]
  oneofs = [f for f in functions if f.is_oneof_handler()]
  stfunctions = [f for f in functions if f.is_string_table_handler()]
  _, root_func = builder.get_roots()

  rendering_context = {
      'basename': os.path.basename(out_f),
      'functions': funcs,
      'oneoffunctions': oneofs,
      'stfunctions': stfunctions,
      'root': root_func,
      'generate_repeated_lines': builder.should_generate_repeated_lines(),
      'generate_one_line_handler': builder.should_generate_one_line_handler(),
      'line_prefix': builder.get_line_prefix(),
      'line_suffix': builder.get_line_suffix(),
      'proto_ns': to_cpp_ns(f'{BASE_PROTO_NS}.{name}'),
      'cpp_ns': f'domatolpm::{name}',
  }
  template = environment.get_template('domatolpm.cc.tmpl')
  _render_internal(template, rendering_context, f'{out_f}.cc')
  template = environment.get_template('domatolpm.h.tmpl')
  _render_internal(template, rendering_context, f'{out_f}.h')


def main():
  """Parses a domato grammar and generates its proto and C++ files."""
  parser = argparse.ArgumentParser(
      description=
      'Generate the necessary files for DomatoLPM to function properly.')
  parser.add_argument('-p',
                      '--path',
                      required=True,
                      help='The path to a Domato grammar file.')
  parser.add_argument('-n',
                      '--name',
                      required=True,
                      help='The name of this grammar.')
  parser.add_argument(
      '-f',
      '--file-format',
      required=True,
      help='The path prefix to which the files should be generated.')
  parser.add_argument('-d',
                      '--generated-dir',
                      required=True,
                      help='The path to the target gen directory.')

  args = parser.parse_args()
  g = grammar.Grammar()
  g.parse_from_file(filename=args.path)

  template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                              'templates')
  environment = jinja2.Environment(loader=jinja2.FileSystemLoader(template_dir))
  builder = DomatoBuilder(g)
  builder.parse_grammar()
  builder.simplify()
  render_cpp(environment, args.file_format, args.name, builder)
  render_proto(environment, args.generated_dir, args.file_format, args.name,
               builder)


if __name__ == '__main__':
  main()