#!/usr/bin/python3 -i
#
# Copyright (c) 2019 Collabora, Ltd.
#
# SPDX-License-Identifier: Apache-2.0
#
# Author(s): Ryan Pavlik <ryan.pavlik@collabora.com>
"""Provides utilities to write a script to verify XML registry consistency."""

import re

import networkx as nx

from .algo import RecursiveMemoize
from .attributes import ExternSyncEntry, LengthEntry
from .data_structures import DictOfStringSets
from .util import findNamedElem, getElemName


class XMLChecker:
    def __init__(self, entity_db, conventions, manual_types_to_codes=None,
                 forward_only_types_to_codes=None,
                 reverse_only_types_to_codes=None,
                 suppressions=None):
        """Set up data structures.

        May extend - call:
        `super().__init__(db, conventions, manual_types_to_codes)`
        as the last statement in your function.

        manual_types_to_codes is a dictionary of hard-coded
        "manual" return codes:
        the codes of the value are available for a command if-and-only-if
        the key type is passed as an input.

        forward_only_types_to_codes is additional entries to the above
        that should only be used in the "forward" direction
        (arg type implies return code)

        reverse_only_types_to_codes is additional entries to
        manual_types_to_codes that should only be used in the
        "reverse" direction
        (return code implies arg type)
        """
        self.fail = False
        self.entity = None
        self.errors = DictOfStringSets()
        self.warnings = DictOfStringSets()
        self.db = entity_db
        self.reg = entity_db.registry
        self.handle_data = HandleData(self.reg)
        self.conventions = conventions

        self.CONST_RE = re.compile(r"\bconst\b")
        self.ARRAY_RE = re.compile(r"\[[^]]+\]")

        # Init memoized properties
        self._handle_data = None

        if not manual_types_to_codes:
            manual_types_to_codes = {}
        if not reverse_only_types_to_codes:
            reverse_only_types_to_codes = {}
        if not forward_only_types_to_codes:
            forward_only_types_to_codes = {}

        reverse_codes = DictOfStringSets(reverse_only_types_to_codes)
        forward_codes = DictOfStringSets(forward_only_types_to_codes)
        for k, v in manual_types_to_codes.items():
            forward_codes.add(k, v)
            reverse_codes.add(k, v)

        self.forward_only_manual_types_to_codes = forward_codes.get_dict()
        self.reverse_only_manual_types_to_codes = reverse_codes.get_dict()

        # The presence of some types as input to a function implies the
        # availability of some return codes.
        self.input_type_to_codes = compute_type_to_codes(
            self.handle_data,
            forward_codes,
            extra_op=self.add_extra_codes)
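
        # Illustrative only (hypothetical names): passing
        # manual_types_to_codes={"Device": "ERROR_DEVICE_LOST"} makes
        # input_type_to_codes map "Device", plus every handle type descended
        # from "Device" (thanks to the ancestor propagation performed by
        # compute_type_to_codes), to a set containing "ERROR_DEVICE_LOST".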

        # Some return codes require a type (or its child) in the input.
        self.codes_requiring_input_type = compute_codes_requiring_type(
            self.handle_data,
            reverse_codes
        )

        specified_codes = set(self.codes_requiring_input_type.keys())
        for codes in self.forward_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.reverse_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.input_type_to_codes.values():
            specified_codes.update(codes)

        unrecognized = specified_codes - self.return_codes
        if unrecognized:
            raise RuntimeError("Return code mentioned in script that isn't in the registry: " +
                               ', '.join(unrecognized))

        self.referenced_input_types = ReferencedTypes(self.db, self.is_input)
        self.referenced_api_types = ReferencedTypes(self.db, self.is_api_type)
        if not suppressions:
            suppressions = {}
        self.suppressions = DictOfStringSets(suppressions)

    def is_api_type(self, member_elem):
        """Return true if the member/parameter ElementTree passed is from this API.

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        return self.conventions.type_prefix in membertext

    def is_input(self, member_elem):
        """Return true if the member/parameter ElementTree passed is
        considered "input".

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        if self.conventions.type_prefix not in membertext:
            return False

        ret = True
        # Const is always input.
        if self.CONST_RE.search(membertext):
            ret = True

        # Arrays and pointers that aren't const are always output.
        elif "*" in membertext:
            ret = False
        elif self.ARRAY_RE.search(membertext):
            ret = False

        return ret

    def add_extra_codes(self, types_to_codes):
        """Add any desired entries to the types-to-codes DictOfStringSets
        before performing "ancestor propagation".

        Passed to compute_type_to_codes as the extra_op.

        May override."""
        pass

    def should_skip_checking_codes(self, name):
        """Return True if more than the basic validation of return codes should
        be skipped for a command.

        May override."""

        return self.conventions.should_skip_checking_codes

    def get_codes_for_command_and_type(self, cmd_name, type_name):
        """Return a set of error codes expected due to having
        an input argument of type type_name.

        The cmd_name is passed for use by extending methods.

        May extend."""
        return self.input_type_to_codes.get(type_name, set())

    def check(self):
        """Iterate through the registry, looking for consistency problems.

        Outputs error messages at the end."""
        # Iterate through commands, looking for consistency problems.
        for name, info in self.reg.cmddict.items():
            self.set_error_context(entity=name, elem=info.elem)

            self.check_command(name, info)

        for name, info in self.reg.typedict.items():
            cat = info.elem.get('category')
            if not cat:
                # This is an external thing, skip it.
                continue
            self.set_error_context(entity=name, elem=info.elem)

            self.check_type(name, info, cat)

        for name, info in self.reg.extdict.items():
            if info.elem.get('supported') != self.conventions.xml_api_name:
                # Skip unsupported extensions
                continue
            self.set_error_context(entity=name, elem=info.elem)
            self.check_extension(name, info)

        entities_with_messages = set(
            self.errors.keys()).union(self.warnings.keys())
        if entities_with_messages:
            print('xml_consistency/consistency_tools error and warning messages follow.')

        for entity in entities_with_messages:
            print()
            print('-------------------')
            print('Messages for', entity)
            print()
            messages = self.errors.get(entity)
            if messages:
                for m in messages:
                    print('Error:', m)

            messages = self.warnings.get(entity)
            if messages:
                for m in messages:
                    print('Warning:', m)

    def check_param(self, param):
        """Check a member of a struct or a param of a function.

        Called from check_params.

        May extend."""
        param_name = getElemName(param)
        externsyncs = ExternSyncEntry.parse_externsync_from_param(param)
        if externsyncs:
            for entry in externsyncs:
                if entry.entirely_extern_sync:
                    if len(externsyncs) > 1:
                        self.record_error("Comma-separated list in externsync attribute includes 'true' for",
                                          param_name)
                else:
                    # member name
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    if entry.member != param_name:
                        self.record_error("externsync attribute for", param_name,
                                          "refers to some other member/parameter:", entry.member)

    def check_params(self, params):
        """Check the members of a struct or params of a function.

        Called from check_type and check_command.

        May extend."""
        for param in params:
            self.check_param(param)

            # Check for parameters referenced by len= attribute
            lengths = LengthEntry.parse_len_from_param(param)
            if lengths:
                for entry in lengths:
                    if not entry.other_param_name:
                        continue
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    other_param = findNamedElem(params, entry.other_param_name)
                    if other_param is None:
                        self.record_error("References a non-existent parameter/member in the length of",
                                          getElemName(param), ":", entry.other_param_name)

    def check_type(self, name, info, category):
        """Check a type's XML data for consistency.

        Called from check.

        May extend."""
        if category == 'struct':
            if not name.startswith(self.conventions.type_prefix):
                self.record_error("Name does not start with",
                                  self.conventions.type_prefix)
            members = info.elem.findall('member')
            self.check_params(members)

            # Check the structure type member, if present.
            type_member = findNamedElem(
                members, self.conventions.structtype_member_name)
            if type_member is not None:
                val = type_member.get('values')
                if val:
                    expected = self.conventions.generate_structure_type_from_name(
                        name)
                    if val != expected:
                        self.record_error("Type has incorrect type-member value: expected",
                                          expected, "got", val)

        elif category == "bitmask":
            if 'Flags' not in name:
                self.record_error("Name of bitmask doesn't include 'Flags'")

    def check_extension(self, name, info):
        """Check an extension's XML data for consistency.

        Called from check.

        May extend."""
        pass

    def check_command(self, name, info):
        """Check a command's XML data for consistency.

        Called from check.

        May extend."""
        elem = info.elem

        self.check_params(elem.findall('param'))

        # Some minimal return code checking
        errorcodes = elem.get("errorcodes")
        if errorcodes:
            errorcodes = errorcodes.split(",")
        else:
            errorcodes = []

        successcodes = elem.get("successcodes")
        if successcodes:
            successcodes = successcodes.split(",")
        else:
            successcodes = []

        if not successcodes and not errorcodes:
            # Early out if no return codes.
            return

        # Create a set for each group of codes, and check that
        # they aren't duplicated within or between groups.
        errorcodes_set = set(errorcodes)
        if len(errorcodes) != len(errorcodes_set):
            self.record_error("Contains a duplicate in errorcodes")

        successcodes_set = set(successcodes)
        if len(successcodes) != len(successcodes_set):
            self.record_error("Contains a duplicate in successcodes")

        if not successcodes_set.isdisjoint(errorcodes_set):
            self.record_error("Has errorcodes and successcodes that overlap")

        self.check_command_return_codes_basic(
            name, info, successcodes_set, errorcodes_set)

        # Continue to further return code checking if not "complicated"
        if not self.should_skip_checking_codes(name):
            codes_set = successcodes_set.union(errorcodes_set)
            self.check_command_return_codes(
                name, info, successcodes_set, errorcodes_set, codes_set)

    def check_command_return_codes_basic(self, name, info,
                                         successcodes, errorcodes):
        """Check a command's return codes for consistency.

        Called from check_command on every command.

        May extend."""

        # Check that all error codes include _ERROR_,
        # and that no success codes do.
        for code in errorcodes:
            if "_ERROR_" not in code:
                self.record_error(
                    code, "in errorcodes but doesn't contain _ERROR_")

        for code in successcodes:
            if "_ERROR_" in code:
                self.record_error(code, "in successcodes but contains _ERROR_")

    def check_command_return_codes(self, name, type_info,
                                   successcodes, errorcodes,
                                   codes):
        """Check a command's return codes in-depth for consistency.

        Called from check_command, only if
        `self.should_skip_checking_codes(name)` is False.

        May extend."""
        referenced_input = self.referenced_input_types[name]
        referenced_types = self.referenced_api_types[name]

        # Check that we have all the codes we expect, based on input types.
        for referenced_type in referenced_input:
            required_codes = self.get_codes_for_command_and_type(
                name, referenced_type)
            missing_codes = required_codes - codes
            if missing_codes:
                path = self.referenced_input_types.shortest_path(
                    name, referenced_type)
                path_str = " -> ".join(path)
                self.record_error("Missing expected return code(s)",
                                  ",".join(missing_codes),
                                  "implied because of input of type",
                                  referenced_type,
                                  "found via path",
                                  path_str)

        # Check that, for each code returned by this command that we can
        # associate with a type, we have some type that can provide it.
        # e.g. can't have INSTANCE_LOST without an Instance
        # (or child of Instance).
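        # Illustrative only (hypothetical names): if codes_requiring_input_type
        # maps "ERROR_SESSION_LOST" to {"Session"} plus the descendants of
        # "Session", a command listing "ERROR_SESSION_LOST" must reference at
        # least one of those types, directly or through a referenced struct.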
        for code in codes:

            required_types = self.codes_requiring_input_type.get(code)
            if not required_types:
                # This code doesn't have a known requirement
                continue

            # TODO: do we look at referenced_types or referenced_input here?
            # the latter is stricter
            if not referenced_types.intersection(required_types):
                self.record_error("Unexpected return code", code,
                                  "- none of these types:",
                                  required_types,
                                  "found in the set of referenced types",
                                  referenced_types)

    ###
    # Utility properties/methods
    ###

    def set_error_context(self, entity=None, elem=None):
        """Set the entity and/or element for future record_error calls."""
        self.entity = entity
        self.elem = elem
        self.name = getElemName(elem)
        self.entity_suppressions = self.suppressions.get(getElemName(elem))

    def record_error(self, *args, **kwargs):
        """Record failure and an error message for the current context."""
        message = " ".join((str(x) for x in args))

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.fail = True
        self.errors.add(self.entity, message)

    def record_warning(self, *args, **kwargs):
        """Record a warning message for the current context."""
        message = " ".join((str(x) for x in args))

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.warnings.add(self.entity, message)

    def _is_message_suppressed(self, message):
        """Return True if the given message, for this entity, should be suppressed."""
        if not self.entity_suppressions:
            return False
        for suppress in self.entity_suppressions:
            if suppress in message:
                return True

        return False

    def _prepend_sourceline_to_message(self, message, **kwargs):
        """Prepend a file and/or line reference to the message, if possible.

        If filename is given as a keyword argument, it is used on its own.

        If filename is not given, this will attempt to retrieve the filename and line from an XML element.
        If 'elem' is given as a keyword argument and is not None, it is used to find the line.
        If 'elem' is given as None, no XML elements are looked at.
        If 'elem' is not supplied, the error context element is used.

        If using XML, the filename, if available, is retrieved from the Registry class.
        If using XML and python-lxml is installed, the source line is retrieved from whatever element is chosen."""
        fn = kwargs.get('filename')
        sourceline = None

        if fn is None:
            elem = kwargs.get('elem', self.elem)
            if elem is not None:
                sourceline = getattr(elem, 'sourceline', None)
                if self.reg.filename:
                    fn = self.reg.filename

        if fn is None and sourceline is None:
            return message

        if fn is None:
            return "Line {}: {}".format(sourceline, message)

        if sourceline is None:
            return "{}: {}".format(fn, message)

        return "{}:{}: {}".format(fn, sourceline, message)


class HandleParents(RecursiveMemoize):
    def __init__(self, handle_types):
        self.handle_types = handle_types

        def compute(handle_type):
            immediate_parent = self.handle_types[handle_type].elem.get(
                'parent')

            if immediate_parent is None:
                # No parents, no need to recurse
                return []

            # Support multiple (alternate) parents
            immediate_parents = immediate_parent.split(',')

            # Recurse, combine, and return
            all_parents = immediate_parents[:]
            for parent in immediate_parents:
                all_parents.extend(self[parent])
            return all_parents

        super().__init__(compute, handle_types.keys())


def _always_true(x):
    return True


class ReferencedTypes(RecursiveMemoize):
    """Find all types (optionally matching a predicate) that are referenced
    by a struct or function, recursively."""

    def __init__(self, db, predicate=None):
        """Initialize.

        Provide an EntityDB object and a predicate function."""
        self.db = db

        self.predicate = predicate
        if not self.predicate:
            # Default predicate is "anything goes"
            self.predicate = _always_true

        self._directly_referenced = {}
        self.graph = nx.DiGraph()

        def compute(type_name):
            """Compute and return all types referenced by type_name, recursively, that satisfy the predicate.

            Called by the [] operator in the base class."""
            types = self.directly_referenced(type_name)
            if not types:
                return types

            all_types = set()
            all_types.update(types)
            for t in types:
                referenced = self[t]
                if referenced is not None:
                    # If not leading to a cycle
                    all_types.update(referenced)
            return all_types

        # Initialize base class
        super().__init__(compute, permit_cycles=True)

    def shortest_path(self, source, target):
        """Get the shortest path between one type/function name and another."""
        # Trigger computation
        _ = self[source]

        return nx.algorithms.shortest_path(self.graph, source=source, target=target)

    def directly_referenced(self, type_name):
        """Get all types referenced directly by type_name that satisfy the predicate.

        Memoizes its results."""
        if type_name not in self._directly_referenced:
            members = self.db.getMemberElems(type_name)
            if members:
                types = ((member, member.find("type")) for member in members)
                self._directly_referenced[type_name] = set(
                    type_elem.text for (member, type_elem) in types
                    if type_elem is not None and self.predicate(member))

            else:
                self._directly_referenced[type_name] = set()

            # Update graph
            self.graph.add_node(type_name)
            self.graph.add_edges_from((type_name, t)
                                      for t in self._directly_referenced[type_name])

        return self._directly_referenced[type_name]


class HandleData:
    """Data about all the handle types available in an API specification."""

    def __init__(self, registry):
        self.reg = registry
        self._handle_types = None
        self._ancestors = None
        self._descendants = None

    @property
    def handle_types(self):
        """Return a dictionary of handle type names to type info."""
        if not self._handle_types:
            # First time requested - compute it.
            self._handle_types = {
                type_name: type_info
                for type_name, type_info in self.reg.typedict.items()
                if type_info.elem.get('category') == 'handle'
            }
        return self._handle_types

    @property
    def ancestors_dict(self):
        """Return a dictionary of handle type names to sets of ancestors."""
        if not self._ancestors:
            # First time requested - compute it.
            self._ancestors = HandleParents(self.handle_types).get_dict()
        return self._ancestors

    @property
    def descendants_dict(self):
        """Return a dictionary of handle type names to sets of descendants."""
        if not self._descendants:
            # First time requested - compute it.

            handle_parents = self.ancestors_dict

            def get_descendants(handle):
                return set(h for h in handle_parents.keys()
                           if handle in handle_parents[h])

            self._descendants = {
                h: get_descendants(h)
                for h in handle_parents.keys()
            }
        return self._descendants


def compute_type_to_codes(handle_data, types_to_codes, extra_op=None):
    """Compute a DictOfStringSets of input type to required return codes.

    - handle_data is a HandleData instance.
    - types_to_codes is a dictionary of type names to strings or string
      collections of return codes.
    - extra_op, if any, is called after populating the output from the input
      dictionary, but before propagation of parent codes to child types.
      extra_op is called with the in-progress DictOfStringSets.

    Returns a DictOfStringSets of input type name to set of required return
    code names.
    """
    # Initialize with the supplied "manual" codes
    types_to_codes = DictOfStringSets(types_to_codes)

    # Dynamically generate more codes, if desired
    if extra_op:
        extra_op(types_to_codes)

    # Final post-processing

    # Any handle can result in its parent handle's codes too.

    handle_ancestors = handle_data.ancestors_dict
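
    # Worked example (hypothetical names): if ancestors_dict contains
    # {"CommandBuffer": ["Device", "Instance"]} and types_to_codes maps
    # "Device" to {"ERROR_DEVICE_LOST"}, the loop below records
    # "ERROR_DEVICE_LOST" as an extra code for "CommandBuffer" as well.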

    extra_handle_codes = {}
    for handle_type, ancestors in handle_ancestors.items():
        codes = set()
        # The sets of return codes corresponding to each ancestor type.
        ancestors_codes = (types_to_codes.get(ancestor, set())
                           for ancestor in ancestors)
        for parent_codes in ancestors_codes:
            codes.update(parent_codes)
        extra_handle_codes[handle_type] = codes

    for handle_type, extras in extra_handle_codes.items():
        types_to_codes.add(handle_type, extras)

    return types_to_codes


def compute_codes_requiring_type(handle_data, types_to_codes, registry=None):
    """Compute a DictOfStringSets of return codes to a set of input types able
    to provide the ability to generate that code.

    handle_data is a HandleData instance.
    types_to_codes is a dictionary of input types to associated return codes
    (same format as for input to compute_type_to_codes, may use same dict).
    This will invert that relationship, and also permit any "child handles"
    to satisfy a requirement for a parent in producing a code.

    Returns a DictOfStringSets of return code name to the set of parameter
    types that would allow that return code.
    """
    # Use DictOfStringSets to normalize the input into a dict with values
    # that are sets of strings
    in_dict = DictOfStringSets(types_to_codes)

    handle_descendants = handle_data.descendants_dict

    out = DictOfStringSets()
    for in_type, code_set in in_dict.items():
        descendants = handle_descendants.get(in_type)
        for code in code_set:
            out.add(code, in_type)
            if descendants:
                out.add(code, descendants)

    return out
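

# Usage sketch (illustrative only): a per-API consistency script would
# typically subclass XMLChecker and drive it roughly as below. The entity
# database, conventions object, and the type/code names are hypothetical
# placeholders, not something defined by this module. Note that
# XMLChecker.__init__ consults self.return_codes, so a subclass should
# populate it before calling super().__init__().
#
#     class Checker(XMLChecker):
#         def __init__(self, entity_db, conventions):
#             manual = {"Device": "ERROR_DEVICE_LOST"}
#             self.return_codes = frozenset(("SUCCESS", "ERROR_DEVICE_LOST"))
#             super().__init__(entity_db, conventions,
#                              manual_types_to_codes=manual)
#
#     ckr = Checker(make_entity_db(), make_conventions())
#     ckr.check()
#     if ckr.fail:
#         sys.exit(1)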