#!/usr/bin/python3 -i
#
# Copyright (c) 2019 Collabora, Ltd.
#
# SPDX-License-Identifier: Apache-2.0
#
# Author(s): Ryan Pavlik <ryan.pavlik@collabora.com>
"""Provides utilities to write a script to verify XML registry consistency."""

import re

import networkx as nx

from .algo import RecursiveMemoize
from .attributes import ExternSyncEntry, LengthEntry
from .data_structures import DictOfStringSets
from .util import findNamedElem, getElemName


class XMLChecker:
    """Base class for checking an API XML registry for consistency problems.

    Subclass this, override or extend the hooks marked "May override" or
    "May extend", then call check() to run all checks and print a report.
    """

    def __init__(self, entity_db, conventions, manual_types_to_codes=None,
                 forward_only_types_to_codes=None,
                 reverse_only_types_to_codes=None,
                 suppressions=None):
        """Set up data structures.

        May extend - call:
        `super().__init__(db, conventions, manual_types_to_codes)`
        as the last statement in your function.

        manual_types_to_codes is a dictionary of hard-coded
        "manual" return codes:
        the codes of the value are available for a command if-and-only-if
        the key type is passed as an input.

        forward_only_types_to_codes is additional entries to the above
        that should only be used in the "forward" direction
        (arg type implies return code)

        reverse_only_types_to_codes is additional entries to
        manual_types_to_codes that should only be used in the
        "reverse" direction
        (return code implies arg type)

        suppressions is an optional dictionary keyed by element name
        (as returned by getElemName). Its values are strings or string
        collections. A recorded error/warning for that element is dropped
        if the message contains any of those strings.

        Note: self.return_codes (the set of return codes known to the
        registry) is read here but is not assigned by this class. It must
        be provided beforehand, e.g. assigned by the subclass before
        calling this method, or defined as a property.

        Raises RuntimeError if any code mentioned in the supplied tables
        is not in self.return_codes.
        """
        self.fail = False
        self.entity = None
        # Error context, normally set by set_error_context(). Initialized
        # here so record_error/record_warning are safe to call before then.
        self.elem = None
        self.name = None
        self.entity_suppressions = None
        self.errors = DictOfStringSets()
        self.warnings = DictOfStringSets()
        self.db = entity_db
        self.reg = entity_db.registry
        self.handle_data = HandleData(self.reg)
        self.conventions = conventions

        self.CONST_RE = re.compile(r"\bconst\b")
        self.ARRAY_RE = re.compile(r"\[[^]]+\]")

        # Init memoized properties
        self._handle_data = None

        if not manual_types_to_codes:
            manual_types_to_codes = {}
        if not reverse_only_types_to_codes:
            reverse_only_types_to_codes = {}
        if not forward_only_types_to_codes:
            forward_only_types_to_codes = {}

        # The "manual" entries apply in both directions, so merge them into
        # each of the direction-specific tables.
        reverse_codes = DictOfStringSets(reverse_only_types_to_codes)
        forward_codes = DictOfStringSets(forward_only_types_to_codes)
        for k, v in manual_types_to_codes.items():
            forward_codes.add(k, v)
            reverse_codes.add(k, v)

        self.forward_only_manual_types_to_codes = forward_codes.get_dict()
        self.reverse_only_manual_types_to_codes = reverse_codes.get_dict()

        # The presence of some types as input to a function imply the
        # availability of some return codes.
        self.input_type_to_codes = compute_type_to_codes(
            self.handle_data,
            forward_codes,
            extra_op=self.add_extra_codes)

        # Some return codes require a type (or its child) in the input.
        self.codes_requiring_input_type = compute_codes_requiring_type(
            self.handle_data,
            reverse_codes
        )

        # Every code mentioned anywhere in the tables must exist in the
        # registry, to catch typos in the checker script itself.
        specified_codes = set(self.codes_requiring_input_type.keys())
        for codes in self.forward_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.reverse_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.input_type_to_codes.values():
            specified_codes.update(codes)

        unrecognized = specified_codes - self.return_codes
        if unrecognized:
            raise RuntimeError("Return code mentioned in script that isn't in the registry: " +
                               ', '.join(sorted(unrecognized)))

        self.referenced_input_types = ReferencedTypes(self.db, self.is_input)
        self.referenced_api_types = ReferencedTypes(self.db, self.is_api_type)
        if not suppressions:
            suppressions = {}
        self.suppressions = DictOfStringSets(suppressions)

    def is_api_type(self, member_elem):
        """Return true if the member/parameter ElementTree passed is from this API.

        The test is whether the element's full text contains the API's
        type prefix.

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        return self.conventions.type_prefix in membertext

    def is_input(self, member_elem):
        """Return true if the member/parameter ElementTree passed is
        considered "input".

        Only members mentioning the API type prefix are considered. Such a
        member is input if it is const-qualified, or if it is neither a
        pointer nor an array.

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        if self.conventions.type_prefix not in membertext:
            return False

        # Const is always input.
        if self.CONST_RE.search(membertext):
            return True

        # Arrays and pointers that aren't const are always output.
        if "*" in membertext or self.ARRAY_RE.search(membertext):
            return False

        return True

    def add_extra_codes(self, types_to_codes):
        """Add any desired entries to the types-to-codes DictOfStringSets
        before performing "ancestor propagation".

        Passed to compute_type_to_codes as the extra_op.

        May override."""
        pass

    def should_skip_checking_codes(self, name):
        """Return True if more than the basic validation of return codes should
        be skipped for a command.

        The default ignores name and returns
        conventions.should_skip_checking_codes.

        May override."""

        return self.conventions.should_skip_checking_codes

    def get_codes_for_command_and_type(self, cmd_name, type_name):
        """Return a set of error codes expected due to having
        an input argument of type type_name.

        The cmd_name is passed for use by extending methods.

        May extend."""
        return self.input_type_to_codes.get(type_name, set())

    def check(self):
        """Iterate through the registry, looking for consistency problems.

        Outputs error messages at the end, grouped by entity and sorted so
        the report is stable from run to run."""
        # Iterate through commands, looking for consistency problems.
        for name, info in self.reg.cmddict.items():
            self.set_error_context(entity=name, elem=info.elem)

            self.check_command(name, info)

        for name, info in self.reg.typedict.items():
            cat = info.elem.get('category')
            if not cat:
                # This is an external thing, skip it.
                continue
            self.set_error_context(entity=name, elem=info.elem)

            self.check_type(name, info, cat)

        # check_extension is called for all extensions, even 'disabled'
        # ones, but some checks may be skipped depending on extension
        # status.
        for name, info in self.reg.extdict.items():
            self.set_error_context(entity=name, elem=info.elem)
            self.check_extension(name, info)

        self.check_format()

        entities_with_messages = set(
            self.errors.keys()).union(self.warnings.keys())
        if entities_with_messages:
            print('xml_consistency/consistency_tools error and warning messages follow.')

        # Sets have no stable order; sort so output is deterministic.
        # key=str tolerates a non-string (e.g. None) entity.
        for entity in sorted(entities_with_messages, key=str):
            print()
            print('-------------------')
            print('Messages for', entity)
            print()
            messages = self.errors.get(entity)
            if messages:
                for m in sorted(messages):
                    print('Error:', m)

            messages = self.warnings.get(entity)
            if messages:
                for m in sorted(messages):
                    print('Warning:', m)

    def check_param(self, param):
        """Check a member of a struct or a param of a function.

        Currently validates the externsync attribute: 'true' must not be
        part of a comma-separated list, and named entries must refer to
        this parameter itself.

        Called from check_params.

        May extend."""
        param_name = getElemName(param)
        externsyncs = ExternSyncEntry.parse_externsync_from_param(param)
        if externsyncs:
            for entry in externsyncs:
                if entry.entirely_extern_sync:
                    if len(externsyncs) > 1:
                        self.record_error("Comma-separated list in externsync attribute includes 'true' for",
                                          param_name)
                else:
                    # member name
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    if entry.member != param_name:
                        self.record_error("externsync attribute for", param_name,
                                          "refers to some other member/parameter:", entry.member)

    def check_params(self, params):
        """Check the members of a struct or params of a function.

        Calls check_param on each, and verifies that any other
        member/parameter named in a len= attribute exists in params.

        Called from check_type and check_command.

        May extend."""
        for param in params:
            self.check_param(param)

            # Check for parameters referenced by len= attribute
            lengths = LengthEntry.parse_len_from_param(param)
            if lengths:
                for entry in lengths:
                    if not entry.other_param_name:
                        continue
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    other_param = findNamedElem(params, entry.other_param_name)
                    if other_param is None:
                        self.record_error("References a non-existent parameter/member in the length of",
                                          getElemName(param), ":", entry.other_param_name)

    def check_type(self, name, info, category):
        """Check a type's XML data for consistency.

        Structs: name prefix, members, and the structure-type member's
        'values' attribute. Bitmasks: name must contain 'Flags'.

        Called from check.

        May extend."""
        if category == 'struct':
            if not name.startswith(self.conventions.type_prefix):
                self.record_error("Name does not start with",
                                  self.conventions.type_prefix)
            members = info.elem.findall('member')
            self.check_params(members)

            # Check the structure type member, if present.
            type_member = findNamedElem(
                members, self.conventions.structtype_member_name)
            if type_member is not None:
                val = type_member.get('values')
                if val:
                    expected = self.conventions.generate_structure_type_from_name(
                        name)
                    if val != expected:
                        self.record_error("Type has incorrect type-member value: expected",
                                          expected, "got", val)

        elif category == "bitmask":
            if 'Flags' not in name:
                self.record_error("Name of bitmask doesn't include 'Flags'")

    def check_extension(self, name, info):
        """Check an extension's XML data for consistency.

        Default implementation does nothing.

        Called from check.

        May extend."""
        pass

    def check_format(self):
        """Perform any whole-registry checks after all per-entity checks.

        Called once from check, after all commands, types, and extensions
        have been checked. The error context is left as set for the last
        extension. Default implementation does nothing.

        May extend."""
        pass

    def check_command(self, name, info):
        """Check a command's XML data for consistency.

        Checks params, then the errorcodes/successcodes attributes:
        no duplicates within either list, and no overlap between them.

        Called from check.

        May extend."""
        elem = info.elem

        self.check_params(elem.findall('param'))

        # Some minimal return code checking
        errorcodes = elem.get("errorcodes")
        if errorcodes:
            errorcodes = errorcodes.split(",")
        else:
            errorcodes = []

        successcodes = elem.get("successcodes")
        if successcodes:
            successcodes = successcodes.split(",")
        else:
            successcodes = []

        if not successcodes and not errorcodes:
            # Early out if no return codes.
            return

        # Create a set for each group of codes, and check that
        # they aren't duplicated within or between groups.
        errorcodes_set = set(errorcodes)
        if len(errorcodes) != len(errorcodes_set):
            self.record_error("Contains a duplicate in errorcodes")

        successcodes_set = set(successcodes)
        if len(successcodes) != len(successcodes_set):
            self.record_error("Contains a duplicate in successcodes")

        if not successcodes_set.isdisjoint(errorcodes_set):
            self.record_error("Has errorcodes and successcodes that overlap")

        self.check_command_return_codes_basic(
            name, info, successcodes_set, errorcodes_set)

        # Continue to further return code checking if not "complicated"
        if not self.should_skip_checking_codes(name):
            codes_set = successcodes_set.union(errorcodes_set)
            self.check_command_return_codes(
                name, info, successcodes_set, errorcodes_set, codes_set)

    def check_command_return_codes_basic(self, name, info,
                                         successcodes, errorcodes):
        """Check a command's return codes for consistency.

        Called from check_command on every command.

        May extend."""

        # Check that all error codes include _ERROR_,
        #  and that no success codes do.
        for code in errorcodes:
            if "_ERROR_" not in code:
                self.record_error(
                    code, "in errorcodes but doesn't contain _ERROR_")

        for code in successcodes:
            if "_ERROR_" in code:
                self.record_error(code, "in successcodes but contain _ERROR_")

    def check_command_return_codes(self, name, type_info,
                                   successcodes, errorcodes,
                                   codes):
        """Check a command's return codes in-depth for consistency.

        Forward direction: each referenced input type implies some codes
        that must be present. Reverse direction: each code with a known
        type requirement must have such a type among the referenced types.

        Called from check_command, only if
        `self.should_skip_checking_codes(name)` is False.

        May extend."""
        referenced_input = self.referenced_input_types[name]
        referenced_types = self.referenced_api_types[name]

        # Check that we have all the codes we expect, based on input types.
        for referenced_type in referenced_input:
            required_codes = self.get_codes_for_command_and_type(
                name, referenced_type)
            missing_codes = required_codes - codes
            if missing_codes:
                path = self.referenced_input_types.shortest_path(
                    name, referenced_type)
                path_str = " -> ".join(path)
                self.record_error("Missing expected return code(s)",
                                  ",".join(sorted(missing_codes)),
                                  "implied because of input of type",
                                  referenced_type,
                                  "found via path",
                                  path_str)

        # Check that, for each code returned by this command that we can
        # associate with a type, we have some type that can provide it.
        # e.g. can't have INSTANCE_LOST without an Instance
        # (or child of Instance).
        for code in codes:

            required_types = self.codes_requiring_input_type.get(code)
            if not required_types:
                # This code doesn't have a known requirement
                continue

            # TODO: do we look at referenced_types or referenced_input here?
            # the latter is stricter
            if not referenced_types.intersection(required_types):
                self.record_error("Unexpected return code", code,
                                  "- none of these types:",
                                  required_types,
                                  "found in the set of referenced types",
                                  referenced_types)

    ###
    # Utility properties/methods
    ###

    def set_error_context(self, entity=None, elem=None):
        """Set the entity and/or element for future record_error calls.

        Also looks up the suppressions for the element's name."""
        self.entity = entity
        self.elem = elem
        self.name = getElemName(elem)
        self.entity_suppressions = self.suppressions.get(self.name)

    def record_error(self, *args, **kwargs):
        """Record failure and an error message for the current context.

        Positional args are str()-ed and joined with spaces. Keyword args
        (filename, elem) are forwarded to _prepend_sourceline_to_message.
        Suppressed messages are dropped and do not set self.fail."""
        message = " ".join((str(x) for x in args))

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.fail = True
        self.errors.add(self.entity, message)

    def record_warning(self, *args, **kwargs):
        """Record a warning message for the current context.

        Same argument handling as record_error, but does not set self.fail."""
        message = " ".join((str(x) for x in args))

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.warnings.add(self.entity, message)

    def _is_message_suppressed(self, message):
        """Return True if the given message, for this entity, should be suppressed.

        A message is suppressed if it contains any of the current entity's
        suppression strings as a substring."""
        if not self.entity_suppressions:
            return False
        return any(suppress in message for suppress in self.entity_suppressions)

    def _prepend_sourceline_to_message(self, message, **kwargs):
        """Prepend a file and/or line reference to the message, if possible.

        If filename is given as a keyword argument, it is used on its own.

        If filename is not given, this will attempt to retrieve the filename and line from an XML element.
        If 'elem' is given as a keyword argument and is not None, it is used to find the line.
        If 'elem' is given as None, no XML elements are looked at.
        If 'elem' is not supplied, the error context element is used.

        If using XML, the filename, if available, is retrieved from the Registry class.
        If using XML and python-lxml is installed, the source line is retrieved from whatever element is chosen."""
        fn = kwargs.get('filename')
        sourceline = None

        if fn is None:
            elem = kwargs.get('elem', self.elem)
            if elem is not None:
                # Only lxml elements carry a sourceline attribute.
                sourceline = getattr(elem, 'sourceline', None)
                if self.reg.filename:
                    fn = self.reg.filename

        if fn is None and sourceline is None:
            return message

        if fn is None:
            return "Line {}: {}".format(sourceline, message)

        if sourceline is None:
            return "{}: {}".format(fn, message)

        return "{}:{}: {}".format(fn, sourceline, message)


class HandleParents(RecursiveMemoize):
    """Memoized map from a handle type name to all of its ancestor handle names.

    The 'parent' XML attribute may be a comma-separated list of alternate
    parents; all of them, and their ancestors recursively, are included.
    Each value is a list, which may contain duplicates if an ancestor is
    reachable through more than one path.
    """

    def __init__(self, handle_types):
        self.handle_types = handle_types

        def compute(handle_type):
            immediate_parent = self.handle_types[handle_type].elem.get(
                'parent')

            if immediate_parent is None:
                # No parents, no need to recurse
                return []

            # Support multiple (alternate) parents
            immediate_parents = immediate_parent.split(',')

            # Recurse, combine, and return
            all_parents = immediate_parents[:]
            for parent in immediate_parents:
                all_parents.extend(self[parent])
            return all_parents

        super().__init__(compute, handle_types.keys())


def _always_true(x):
    """Predicate that accepts anything; default for ReferencedTypes."""
    return True


class ReferencedTypes(RecursiveMemoize):
    """Find all types(optionally matching a predicate) that are referenced
    by a struct or function, recursively."""

    def __init__(self, db, predicate=None):
        """Initialize.

        Provide an EntityDB object and a predicate function.
        The predicate receives a member/param element and returns whether
        that member's type should be counted."""
        self.db = db

        self.predicate = predicate
        if not self.predicate:
            # Default predicate is "anything goes"
            self.predicate = _always_true

        self._directly_referenced = {}
        # Directed graph: edge (a, b) means a directly references type b.
        # Used by shortest_path to explain how a type was reached.
        self.graph = nx.DiGraph()

        def compute(type_name):
            """Compute and return all types referenced by type_name, recursively, that satisfy the predicate.

            Called by the [] operator in the base class."""
            types = self.directly_referenced(type_name)
            if not types:
                return types

            all_types = set()
            all_types.update(types)
            for t in types:
                referenced = self[t]
                if referenced is not None:
                    # If not leading to a cycle
                    all_types.update(referenced)
            return all_types

        # Initialize base class
        super().__init__(compute, permit_cycles=True)

    def shortest_path(self, source, target):
        """Get the shortest path between one type/function name and another."""
        # Trigger computation
        _ = self[source]

        return nx.algorithms.shortest_path(self.graph, source=source, target=target)

    def directly_referenced(self, type_name):
        """Get all types referenced directly by type_name that satisfy the predicate.

        Memoizes its results, and records the corresponding edges in
        self.graph the first time type_name is seen."""
        if type_name not in self._directly_referenced:
            members = self.db.getMemberElems(type_name)
            if members:
                types = ((member, member.find("type")) for member in members)
                referenced = set(type_elem.text for (member, type_elem) in types
                                 if type_elem is not None and self.predicate(member))
            else:
                referenced = set()
            self._directly_referenced[type_name] = referenced

            # Update graph. Only needed once per name: the memoized result
            # never changes and the graph only ever grows.
            self.graph.add_node(type_name)
            self.graph.add_edges_from((type_name, t) for t in referenced)

        return self._directly_referenced[type_name]


class HandleData:
    """Data about all the handle types available in an API specification."""

    def __init__(self, registry):
        self.reg = registry
        # Lazily computed on first access of the corresponding property.
        self._handle_types = None
        self._ancestors = None
        self._descendants = None

    @property
    def handle_types(self):
        """Return a dictionary of handle type names to type info."""
        if self._handle_types is None:
            # First time requested - compute it.
            self._handle_types = {
                type_name: type_info
                for type_name, type_info in self.reg.typedict.items()
                if type_info.elem.get('category') == 'handle'
            }
        return self._handle_types

    @property
    def ancestors_dict(self):
        """Return a dictionary of handle type names to their ancestors.

        Values are as computed by HandleParents (lists of ancestor names)."""
        if self._ancestors is None:
            # First time requested - compute it.
            self._ancestors = HandleParents(self.handle_types).get_dict()
        return self._ancestors

    @property
    def descendants_dict(self):
        """Return a dictionary of handle type names to sets of descendants."""
        if self._descendants is None:
            # First time requested - compute it.

            handle_parents = self.ancestors_dict

            def get_descendants(handle):
                return set(h for h in handle_parents.keys()
                           if handle in handle_parents[h])

            self._descendants = {
                h: get_descendants(h)
                for h in handle_parents.keys()
            }
        return self._descendants


def compute_type_to_codes(handle_data, types_to_codes, extra_op=None):
    """Compute a DictOfStringSets of input type to required return codes.

    - handle_data is a HandleData instance.
    - types_to_codes is a dictionary of type names to strings or string
      collections of return codes.
    - extra_op, if any, is called after populating the output from the input
      dictionary, but before propagation of parent codes to child types.
      extra_op is called with the in-progress DictOfStringSets.

    Returns a DictOfStringSets of input type name to set of required return
    code names. Each handle type additionally receives the codes of every
    one of its ancestor handle types.
    """
    # Initialize with the supplied "manual" codes
    types_to_codes = DictOfStringSets(types_to_codes)

    # Dynamically generate more codes, if desired
    if extra_op:
        extra_op(types_to_codes)

    # Final post-processing

    # Any handle can result in its parent handle's codes too.
    handle_ancestors = handle_data.ancestors_dict

    # Gather all the inherited codes first, then apply them.
    extra_handle_codes = {}
    for handle_type, ancestors in handle_ancestors.items():
        codes = set()
        for ancestor in ancestors:
            # Must update in place: set.union would return a new set and
            # leave `codes` empty.
            codes.update(types_to_codes.get(ancestor, set()))
        extra_handle_codes[handle_type] = codes

    for handle_type, extras in extra_handle_codes.items():
        types_to_codes.add(handle_type, extras)

    return types_to_codes


def compute_codes_requiring_type(handle_data, types_to_codes, registry=None):
    """Compute a DictOfStringSets of return codes to a set of input types able
    to provide the ability to generate that code.

    handle_data is a HandleData instance.
    types_to_codes is a dictionary of input types to associated return codes
    (same format as for input to compute_type_to_codes, may use same dict).
    This will invert that relationship, and also permit any "child handles"
    to satisfy a requirement for a parent in producing a code.
    registry is accepted for compatibility but is not used.

    Returns a DictOfStringSets of return code name to the set of parameter
    types that would allow that return code.
    """
    # Use DictOfStringSets to normalize the input into a dict with values
    # that are sets of strings
    in_dict = DictOfStringSets(types_to_codes)

    handle_descendants = handle_data.descendants_dict

    out = DictOfStringSets()
    for in_type, code_set in in_dict.items():
        descendants = handle_descendants.get(in_type)
        for code in code_set:
            out.add(code, in_type)
            if descendants:
                out.add(code, descendants)

    return out