#!/usr/bin/python3 -i
#
# Copyright (c) 2019 Collabora, Ltd.
#
# SPDX-License-Identifier: Apache-2.0
#
# Author(s): Ryan Pavlik <ryan.pavlik@collabora.com>
"""Provides utilities to write a script to verify XML registry consistency."""

import re
from typing import Set

import networkx as nx
from networkx.algorithms import shortest_path

from .algo import RecursiveMemoize
from .attributes import ExternSyncEntry, LengthEntry
from .data_structures import DictOfStringSets
from .util import findNamedElem, getElemName, getElemType
from .conventions import ConventionsBase


def _get_extension_tags(reg):
    """Get a set of all author tags registered for use."""
    return {elt.get("name") for elt in reg.tree.findall("./tags/tag[@name]")}


class XMLChecker:
    """Base class for checking an API XML registry for consistency problems.

    Subclass it, set `self.return_codes`, and override/extend the hook
    methods marked "May override"/"May extend" to tailor the checks.
    Errors and warnings are collected per entity and printed by check().
    """

    def __init__(self, entity_db, conventions: ConventionsBase, manual_types_to_codes=None,
                 forward_only_types_to_codes=None,
                 reverse_only_types_to_codes=None,
                 suppressions=None):
        """Set up data structures.

        May extend - call:
        `super().__init__(db, conventions, manual_types_to_codes)`
        as the last statement in your function.

        manual_types_to_codes is a dictionary of hard-coded
        "manual" return codes:
        the codes of the value are available for a command if-and-only-if
        the key type is passed as an input.

        forward_only_types_to_codes is additional entries to the above
        that should only be used in the "forward" direction
        (arg type implies return code)

        reverse_only_types_to_codes is additional entries to
        manual_types_to_codes that should only be used in the
        "reverse" direction
        (return code implies arg type)

        suppressions maps an element name to a collection of substrings:
        any error or warning recorded for that element whose message
        contains one of those substrings is dropped.

        `self.return_codes` (the set of all return code names in the
        registry) is not set here; the subclass must set it before calling
        this constructor.

        Raises RuntimeError if any of the dictionaries above mention a
        return code that is not in `self.return_codes`.
        """
        self.fail = False
        self.entity = None
        # Error context normally populated by set_error_context(); initialized
        # here so record_error()/record_warning() work before it is called.
        self.elem = None
        self.name = None
        self.entity_suppressions = None
        # Extension numbers seen so far; reset by check().
        self.ext_numbers = set()
        self.errors = DictOfStringSets()
        self.warnings = DictOfStringSets()
        self.db = entity_db
        self.reg = entity_db.registry
        self.handle_data = HandleData(self.reg)
        self.conventions = conventions

        self.CONST_RE = re.compile(r"\bconst\b")
        self.ARRAY_RE = re.compile(r"\[[^]]+\]")

        # Init memoized properties
        self._handle_data = None

        if not manual_types_to_codes:
            manual_types_to_codes = {}
        if not reverse_only_types_to_codes:
            reverse_only_types_to_codes = {}
        if not forward_only_types_to_codes:
            forward_only_types_to_codes = {}

        reverse_codes = DictOfStringSets(reverse_only_types_to_codes)
        forward_codes = DictOfStringSets(forward_only_types_to_codes)
        # The "manual" entries apply in both directions.
        for k, v in manual_types_to_codes.items():
            forward_codes.add(k, v)
            reverse_codes.add(k, v)

        self.forward_only_manual_types_to_codes = forward_codes.get_dict()
        self.reverse_only_manual_types_to_codes = reverse_codes.get_dict()

        # The presence of some types as input to a function imply the
        # availability of some return codes.
        self.input_type_to_codes = compute_type_to_codes(
            self.handle_data,
            forward_codes,
            extra_op=self.add_extra_codes)

        # Some return codes require a type (or its child) in the input.
        self.codes_requiring_input_type = compute_codes_requiring_type(
            self.handle_data,
            reverse_codes
        )

        specified_codes = set(self.codes_requiring_input_type.keys())
        for codes in self.forward_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.reverse_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.input_type_to_codes.values():
            specified_codes.update(codes)

        # Must already have been set by the subclass.
        self.return_codes: Set[str]
        unrecognized = specified_codes - self.return_codes
        if unrecognized:
            # Sorted so the message is the same on every run.
            raise RuntimeError("Return code mentioned in script that isn't in the registry: " +
                               ', '.join(sorted(unrecognized)))

        self.referenced_input_types = ReferencedTypes(self.db, self.is_input)
        self.referenced_types = ReferencedTypes(self.db)
        if not suppressions:
            suppressions = {}
        self.suppressions = DictOfStringSets(suppressions)
        self.tags = _get_extension_tags(self.db.registry)

    def is_api_type(self, member_elem):
        """Return true if the member/parameter ElementTree passed is from this API.

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        return self.conventions.type_prefix in membertext

    def is_input(self, member_elem):
        """Return true if the member/parameter ElementTree passed is
        considered "input".

        Only members mentioning the API type prefix can be input.
        Const-qualified members are input; otherwise pointers and
        arrays are output, and anything else is input.

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        if self.conventions.type_prefix not in membertext:
            return False

        # Const is always input.
        if self.CONST_RE.search(membertext):
            return True

        # Arrays and pointers that aren't const are always output.
        if "*" in membertext or self.ARRAY_RE.search(membertext):
            return False

        return True

    def strip_extension_tag(self, name):
        """Remove a single author tag from the end of a name, if any.

        Returns the stripped name and the tag, or the input and None if there was no tag.
        A single trailing underscore left after removing the tag is removed too.
        If several registered tags are suffixes of the name, the longest one
        is stripped.
        """
        # Longest first, so the choice between a tag and a shorter tag that
        # is its suffix does not depend on set iteration order.
        for t in sorted(self.tags, key=len, reverse=True):
            if name.endswith(t):
                name = name[:-(len(t))]
                # endswith() rather than name[-1]: the latter raises
                # IndexError when the entire name was the tag.
                if name.endswith("_"):
                    # remove trailing underscore
                    name = name[:-1]
                return name, t
        return name, None

    def add_extra_codes(self, types_to_codes):
        """Add any desired entries to the types-to-codes DictOfStringSets
        before performing "ancestor propagation".

        Passed to compute_type_to_codes as the extra_op.

        May override."""
        pass

    def should_skip_checking_codes(self, name):
        """Return True if more than the basic validation of return codes should
        be skipped for a command.

        The default ignores name and returns the conventions object's
        should_skip_checking_codes value.

        May override."""

        return self.conventions.should_skip_checking_codes

    def get_codes_for_command_and_type(self, cmd_name, type_name):
        """Return a set of return codes expected due to having
        an input argument of type type_name.

        The cmd_name is passed for use by extending methods.
        Note that you should not use cmd_name to add codes, just to
        filter them out. See get_required_codes_for_command() to do that.

        May extend."""
        return self.input_type_to_codes.get(type_name, set())

    def get_required_codes_for_command(self, cmd_name):
        """Return a set of return codes required due to having a particular name.

        May override."""
        return set()

    def get_forbidden_codes_for_command(self, cmd_name):
        """Return a set of return codes not permitted due to having a particular name.

        May override."""
        return set()

    def check(self):
        """Iterate through the registry, looking for consistency problems.

        Checks every command, every type that has a category, and every
        extension, then calls check_format().

        Outputs error messages at the end, sorted by entity so that output
        is stable between runs."""
        # Iterate through commands, looking for consistency problems.
        for name, info in self.reg.cmddict.items():
            self.set_error_context(entity=name, elem=info.elem)

            self.check_command(name, info)

        for name, info in self.reg.typedict.items():
            cat = info.elem.get('category')
            if not cat:
                # This is an external thing, skip it.
                continue
            self.set_error_context(entity=name, elem=info.elem)

            self.check_type(name, info, cat)

        self.ext_numbers = set()
        for name, info in self.reg.extdict.items():
            self.set_error_context(entity=name, elem=info.elem)

            # Determine if this extension is supported by the API we're
            # testing, and pass that flag to check_extension.
            # For Vulkan, multiple APIs can be specified in the 'supported'
            # attribute.
            supported_apis = info.elem.get('supported', '').split(',')
            supported = self.conventions.xml_api_name in supported_apis
            self.check_extension(name, info, supported)

        self.check_format()

        entities_with_messages = set(
            self.errors.keys()).union(self.warnings.keys())
        if entities_with_messages:
            print('xml_consistency/consistency_tools error and warning messages follow.')

        # Sort: set iteration order varies between runs (string hashing).
        # key=str tolerates a None entity (errors recorded with no context).
        for entity in sorted(entities_with_messages, key=str):
            print()
            print('-------------------')
            print('Messages for', entity)
            print()
            messages = self.errors.get(entity)
            if messages:
                for m in sorted(messages):
                    print('Error:', m)

            messages = self.warnings.get(entity)
            if messages:
                for m in sorted(messages):
                    print('Warning:', m)

    def check_param(self, param):
        """Check a member of a struct or a param of a function.

        Called from check_params.

        May extend."""
        param_name = getElemName(param)
        # Make sure there's something between the type and the name
        # Can't just look at the .tail of <type> for some reason,
        # so instead we look to see if anything's between
        # type's text and name's text in the itertext.
        # If there's no text between the tags, there will be no string
        # between those tags' text in itertext()
        text_parts = list(param.itertext())
        type_idx = text_parts.index(getElemType(param))
        name_idx = text_parts.index(param_name)
        if name_idx - type_idx == 1:
            self.record_error(
                "Space (or other delimiter text) missing between </type> and <name> for param/member named",
                param_name)

        # Check external sync entries
        externsyncs = ExternSyncEntry.parse_externsync_from_param(param)
        if externsyncs:
            for entry in externsyncs:
                if entry.entirely_extern_sync:
                    if len(externsyncs) > 1:
                        self.record_error("Comma-separated list in externsync attribute includes 'true' for",
                                          param_name)
                else:
                    # member name
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    if entry.member != param_name:
                        self.record_error("externsync attribute for", param_name,
                                          "refers to some other member/parameter:", entry.member)

    def check_params(self, params):
        """Check the members of a struct or params of a function.

        Called from check_type and check_command.

        May extend."""
        for param in params:
            self.check_param(param)

            # Check for parameters referenced by len= attribute
            lengths = LengthEntry.parse_len_from_param(param)
            if lengths:
                for entry in lengths:
                    if not entry.other_param_name:
                        continue
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    other_param = findNamedElem(params, entry.other_param_name)
                    if other_param is None:
                        self.record_error("References a non-existent parameter/member in the length of",
                                          getElemName(param), ":", entry.other_param_name)

    def check_referenced_type(self, desc, ref_name):
        """
        Record an error if a type mentioned somewhere doesn't exist.

        :param desc: Description of where this type reference was found,
                     for the error message.
        :param ref_name: The name of the referenced type. If false-ish (incl. None),
                         checking is skipped, so OK to pass the results of
                         info.elem.get() directly
        """
        if ref_name:
            entity = self.db.findEntity(ref_name)
            if not entity:
                self.record_error("Unknown type named in", desc, ":",
                                  ref_name)

    def check_type(self, name, info, category):
        """Check a type's XML data for consistency.

        Called from check.

        May extend."""
        if category == 'struct':
            if not name.startswith(self.conventions.type_prefix):
                self.record_error("Name does not start with",
                                  self.conventions.type_prefix)
            members = info.elem.findall('member')
            self.check_params(members)

            # Check the structure type member, if present.
            type_member = findNamedElem(
                members, self.conventions.structtype_member_name)
            if type_member is not None:
                val = type_member.get('values')
                if val:
                    expected = self.conventions.generate_structure_type_from_name(
                        name)
                    if val != expected:
                        self.record_error("Type has incorrect type-member value: expected",
                                          expected, "got", val)

            # Check structextends attribute, if present.
            # For Vulkan, this may be a comma-separated list of multiple types
            # (empty entries are skipped by check_referenced_type).
            for extended_type in info.elem.get("structextends", '').split(','):
                self.check_referenced_type("'structextends' attribute", extended_type)

            # Check parentstruct attribute, if present.
            self.check_referenced_type("'parentstruct' attribute", info.elem.get("parentstruct"))

        elif category == "bitmask":
            if 'Flags' not in name:
                self.record_error("Name of bitmask doesn't include 'Flags'")
        elif category == "handle":
            # Check parent attribute, if present.
            self.check_referenced_type("'parent' attribute", info.elem.get("parent"))

    def check_extension(self, name, info, supported):
        """Check an extension's XML data for consistency.

        Called from check.

        May extend."""

        # Verify that each extension has a unique number
        # (number '0' is exempt and may repeat).
        extension_number = info.elem.get('number')
        if extension_number is not None and extension_number != '0':
            if extension_number in self.ext_numbers:
                self.record_error('Duplicate extension number ' + extension_number)
            else:
                self.ext_numbers.add(extension_number)

    def check_format(self):
        """Hook for registry-wide checks not tied to a single command,
        type, or extension (presumably formatting checks, per the name).

        Called from check, after all commands, types, and extensions have
        been checked. Does nothing by default.

        May extend."""
        pass

    def check_command(self, name, info):
        """Check a command's XML data for consistency.

        Called from check.

        May extend."""
        elem = info.elem

        self.check_params(elem.findall('param'))

        # Some minimal return code checking
        errorcodes = elem.get("errorcodes")
        if errorcodes:
            errorcodes = errorcodes.split(",")
        else:
            errorcodes = []

        successcodes = elem.get("successcodes")
        if successcodes:
            successcodes = successcodes.split(",")
        else:
            successcodes = []

        if not successcodes and not errorcodes:
            # Early out if no return codes.
            return

        # Create a set for each group of codes, and check that
        # they aren't duplicated within or between groups.
        errorcodes_set = set(errorcodes)
        if len(errorcodes) != len(errorcodes_set):
            self.record_error("Contains a duplicate in errorcodes")

        successcodes_set = set(successcodes)
        if len(successcodes) != len(successcodes_set):
            self.record_error("Contains a duplicate in successcodes")

        if not successcodes_set.isdisjoint(errorcodes_set):
            self.record_error("Has errorcodes and successcodes that overlap")

        self.check_command_return_codes_basic(
            name, info, successcodes_set, errorcodes_set)

        # Continue to further return code checking if not "complicated"
        if not self.should_skip_checking_codes(name):
            codes_set = successcodes_set.union(errorcodes_set)
            self.check_command_return_codes(
                name, info, successcodes_set, errorcodes_set, codes_set)

    def check_command_return_codes_basic(self, name, info,
                                         successcodes, errorcodes):
        """Check a command's return codes for consistency.

        Called from check_command on every command.

        May extend."""

        # Check that all error codes include _ERROR_,
        #  and that no success codes do.
        for code in errorcodes:
            if "_ERROR_" not in code:
                self.record_error(
                    code, "in errorcodes but doesn't contain _ERROR_")

        for code in successcodes:
            if "_ERROR_" in code:
                self.record_error(code, "in successcodes but contain _ERROR_")

    def check_command_return_codes(self, name, type_info,
                                   successcodes, errorcodes,
                                   codes):
        """Check a command's return codes in-depth for consistency.

        Called from check_command, only if
        `self.should_skip_checking_codes(name)` is False.

        May extend."""
        referenced_input = self.referenced_input_types[name]
        referenced_types = self.referenced_types[name]
        error_prefix = self.conventions.api_prefix + "ERROR"

        bad_success = {x for x in successcodes if x.startswith(error_prefix)}
        if bad_success:
            self.record_error("Found error code(s)",
                              ",".join(bad_success),
                              "listed in the successcodes attributes")

        bad_errors = {x for x in errorcodes if not x.startswith(error_prefix)}
        if bad_errors:
            self.record_error("Found success code(s)",
                              ",".join(bad_errors),
                              "listed in the errorcodes attributes")

        # Check that we have all the codes we expect, based on input types.
        for referenced_type in referenced_input:
            required_codes = self.get_codes_for_command_and_type(
                name, referenced_type)
            missing_codes = required_codes - codes
            if missing_codes:
                path = self.referenced_input_types.shortest_path(
                    name, referenced_type)
                path_str = " -> ".join(path)
                self.record_error("Missing expected return code(s)",
                                  ",".join(missing_codes),
                                  "implied because of input of type",
                                  referenced_type,
                                  "found via path",
                                  path_str)

        # Check that we have all the codes we expect based on command name.
        missing_codes = self.get_required_codes_for_command(name) - codes
        if missing_codes:
            self.record_error("Missing expected return code(s)",
                              ",".join(missing_codes),
                              "implied because of the name of this command")

        # Check that we don't have any codes forbidden based on command name.
        forbidden = self.get_forbidden_codes_for_command(name).intersection(codes)
        if forbidden:
            self.record_error("Got return code(s)",
                              ", ".join(forbidden),
                              "that were forbidden due to the name of this command")

        # Check that, for each code returned by this command that we can
        # associate with a type, we have some type that can provide it.
        # e.g. can't have INSTANCE_LOST without an Instance
        # (or child of Instance).
        for code in codes:

            required_types = self.codes_requiring_input_type.get(code)
            if not required_types:
                # This code doesn't have a known requirement
                continue

            # TODO: do we look at referenced_types or referenced_input here?
            # the latter is stricter
            if not referenced_types.intersection(required_types):
                self.record_error("Unexpected return code", code,
                                  "- none of these types:",
                                  required_types,
                                  "found in the set of referenced types",
                                  referenced_types)

    ###
    # Utility properties/methods
    ###

    def set_error_context(self, entity=None, elem=None):
        """Set the entity and/or element for future record_error calls.

        Also looks up the suppressions registered for the element's name."""
        self.entity = entity
        self.elem = elem
        self.name = getElemName(elem)
        self.entity_suppressions = self.suppressions.get(self.name)

    def record_error(self, *args, **kwargs):
        """Record failure and an error message for the current context.

        Positional args are str()-ed and joined with spaces. Keyword args
        are passed to _prepend_sourceline_to_message. Suppressed messages
        are dropped without setting self.fail."""
        message = " ".join((str(x) for x in args))

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.fail = True
        self.errors.add(self.entity, message)

    def record_warning(self, *args, **kwargs):
        """Record a warning message for the current context.

        Same message handling as record_error, but does not set self.fail."""
        message = " ".join((str(x) for x in args))

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.warnings.add(self.entity, message)

    def _is_message_suppressed(self, message):
        """Return True if the given message, for this entity, should be suppressed.

        A message is suppressed if it contains any of the current entity's
        suppression substrings."""
        if not self.entity_suppressions:
            return False
        return any(suppress in message for suppress in self.entity_suppressions)

    def _prepend_sourceline_to_message(self, message, **kwargs):
        """Prepend a file and/or line reference to the message, if possible.

        If filename is given as a keyword argument, it is used on its own.

        If filename is not given, this will attempt to retrieve the filename and line from an XML element.
        If 'elem' is given as a keyword argument and is not None, it is used to find the line.
        If 'elem' is given as None, no XML elements are looked at.
        If 'elem' is not supplied, the error context element is used.

        If using XML, the filename, if available, is retrieved from the Registry class.
        If using XML and python-lxml is installed, the source line is retrieved from whatever element is chosen."""
        fn = kwargs.get('filename')
        sourceline = None

        if fn is None:
            elem = kwargs.get('elem', self.elem)
            if elem is not None:
                # Only lxml elements carry a sourceline attribute.
                sourceline = getattr(elem, 'sourceline', None)
                if self.reg.filename:
                    fn = self.reg.filename

        if fn is None and sourceline is None:
            return message

        if fn is None:
            return "Line {}: {}".format(sourceline, message)

        if sourceline is None:
            return "{}: {}".format(fn, message)

        return "{}:{}: {}".format(fn, sourceline, message)


class HandleParents(RecursiveMemoize):
    """Memoized map from a handle type name to a list of its ancestor handle names.

    The list holds the immediate parents (from the comma-separated 'parent'
    attribute) followed by each parent's ancestors, recursively. It may
    contain duplicates when ancestry paths converge."""

    def __init__(self, handle_types):
        self.handle_types = handle_types

        def compute(handle_type):
            immediate_parent = self.handle_types[handle_type].elem.get(
                'parent')

            if immediate_parent is None:
                # No parents, no need to recurse
                return []

            # Support multiple (alternate) parents
            immediate_parents = immediate_parent.split(',')

            # Recurse, combine, and return
            all_parents = immediate_parents[:]
            for parent in immediate_parents:
                all_parents.extend(self[parent])
            return all_parents

        super().__init__(compute, handle_types.keys())


def _always_true(x):
    """Predicate that accepts everything; default for ReferencedTypes."""
    return True


class ReferencedTypes(RecursiveMemoize):
    """Find all types(optionally matching a predicate) that are referenced
    by a struct or function, recursively."""

    def __init__(self, db, predicate=None):
        """Initialize.

        Provide an EntityDB object and a predicate function."""
        self.db = db

        self.predicate = predicate
        if not self.predicate:
            # Default predicate is "anything goes"
            self.predicate = _always_true

        self._directly_referenced = {}
        # Directed graph of "references" edges, used by shortest_path().
        self.graph = nx.DiGraph()

        def compute(type_name):
            """Compute and return all types referenced by type_name, recursively, that satisfy the predicate.

            Called by the [] operator in the base class."""
            types = self.directly_referenced(type_name)
            if not types:
                # Fresh set, so the memoized result never aliases the
                # directly-referenced cache.
                return set()

            all_types = set()
            all_types.update(types)
            for t in types:
                referenced = self[t]
                if referenced is not None:
                    # If not leading to a cycle
                    all_types.update(referenced)
            return all_types

        # Initialize base class
        super().__init__(compute, permit_cycles=True)

    def shortest_path(self, source, target):
        """Get the shortest path between one type/function name and another."""
        # Trigger computation
        _ = self[source]

        return shortest_path(self.graph, source=source, target=target)

    def directly_referenced(self, type_name):
        """Get all types referenced directly by type_name that satisfy the predicate.

        The result includes the type of each member/param accepted by the
        predicate, plus db.childTypes(type_name). Also records the
        corresponding edges in self.graph.

        Memoizes its results."""
        if type_name not in self._directly_referenced:
            members = self.db.getMemberElems(type_name)
            if members:
                types = ((member, member.find("type")) for member in members)
                self._directly_referenced[type_name] = set(type_elem.text for (member, type_elem) in types
                                                           if type_elem is not None and self.predicate(member))

            else:
                self._directly_referenced[type_name] = set()
            children = self.db.childTypes(type_name)
            if children:
                self._directly_referenced[type_name].update(children)
            # Update graph
            self.graph.add_node(type_name)
            self.graph.add_edges_from((type_name, t)
                                      for t in self._directly_referenced[type_name])

        return self._directly_referenced[type_name]


class HandleData:
    """Data about all the handle types available in an API specification."""

    def __init__(self, registry):
        self.reg = registry
        # Lazily computed; None means "not computed yet".
        self._handle_types = None
        self._ancestors = None
        self._descendants = None

    @property
    def handle_types(self):
        """Return a dictionary of handle type names to type info."""
        # "is None" so that an empty result is cached too.
        if self._handle_types is None:
            # First time requested - compute it.
            self._handle_types = {
                type_name: type_info
                for type_name, type_info in self.reg.typedict.items()
                if type_info.elem.get('category') == 'handle'
            }
        return self._handle_types

    @property
    def ancestors_dict(self):
        """Return a dictionary of handle type names to their ancestor handle names
        (as computed by HandleParents)."""
        if self._ancestors is None:
            # First time requested - compute it.
            self._ancestors = HandleParents(self.handle_types).get_dict()
        return self._ancestors

    @property
    def descendants_dict(self):
        """Return a dictionary of handle type names to sets of descendants."""
        if self._descendants is None:
            # First time requested - compute it.

            handle_parents = self.ancestors_dict

            def get_descendants(handle):
                return set(h for h in handle_parents.keys()
                           if handle in handle_parents[h])

            self._descendants = {
                h: get_descendants(h)
                for h in handle_parents.keys()
            }
        return self._descendants


def compute_type_to_codes(handle_data, types_to_codes, extra_op=None):
    """Compute a DictOfStringSets of input type to required return codes.

    - handle_data is a HandleData instance.
    - types_to_codes is a dictionary of type names to strings or string
      collections of return codes.
    - extra_op, if any, is called after populating the output from the input
      dictionary, but before propagation of parent codes to child types.
      extra_op is called with the in-progress DictOfStringSets.

    Returns a DictOfStringSets of input type name to set of required return
    code names.
    """
    # Initialize with the supplied "manual" codes
    types_to_codes = DictOfStringSets(types_to_codes)

    # Dynamically generate more codes, if desired
    if extra_op:
        extra_op(types_to_codes)

    # Final post-processing

    # Any handle can result in its parent handle's codes too.

    handle_ancestors = handle_data.ancestors_dict

    # Collect all additions first, so propagation reads only the
    # pre-propagation codes of each ancestor.
    extra_handle_codes = {}
    for handle_type, ancestors in handle_ancestors.items():
        # The sets of return codes corresponding to each ancestor type.
        ancestors_codes = [types_to_codes.get(ancestor, set())
                           for ancestor in ancestors]
        extra_handle_codes[handle_type] = set().union(*ancestors_codes)

    for handle_type, extras in extra_handle_codes.items():
        types_to_codes.add(handle_type, extras)

    return types_to_codes


def compute_codes_requiring_type(handle_data, types_to_codes, registry=None):
    """Compute a DictOfStringSets of return codes to a set of input types able
    to provide the ability to generate that code.

    handle_data is a HandleData instance.
    types_to_codes is a dictionary of input types to associated return codes
    (same format as for input to compute_type_to_codes, may use same dict).
    This will invert that relationship, and also permit any "child handles"
    to satisfy a requirement for a parent in producing a code.
    registry is unused; it is accepted for backward compatibility.

    Returns a DictOfStringSets of return code name to the set of parameter
    types that would allow that return code.
    """
    # Use DictOfStringSets to normalize the input into a dict with values
    # that are sets of strings
    in_dict = DictOfStringSets(types_to_codes)

    handle_descendants = handle_data.descendants_dict

    out = DictOfStringSets()
    for in_type, code_set in in_dict.items():
        descendants = handle_descendants.get(in_type)
        for code in code_set:
            out.add(code, in_type)
            if descendants:
                out.add(code, descendants)

    return out