• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3 -i
2#
3# Copyright (c) 2019 Collabora, Ltd.
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Author(s):    Ryan Pavlik <ryan.pavlik@collabora.com>
8"""Provides utilities to write a script to verify XML registry consistency."""
9
10import re
11from typing import Set
12
13import networkx as nx
14from networkx.algorithms import shortest_path
15
16from .algo import RecursiveMemoize
17from .attributes import ExternSyncEntry, LengthEntry
18from .data_structures import DictOfStringSets
19from .util import findNamedElem, getElemName, getElemType
20from .conventions import ConventionsBase
21
22
23def _get_extension_tags(reg):
24    """Get a set of all author tags registered for use."""
25    return set(elt.get("name") for elt in reg.tree.findall("./tags/tag[@name]"))
26
27
class XMLChecker:
    """Check an XML API registry for internal consistency.

    Intended to be subclassed by a per-API script. Methods documented as
    "May override" or "May extend" are the customization points. Errors and
    warnings are collected per entity while check() runs and printed at the
    end; ``self.fail`` becomes True as soon as any error is recorded.
    """

    def __init__(self, entity_db,  conventions: ConventionsBase, manual_types_to_codes=None,
                 forward_only_types_to_codes=None,
                 reverse_only_types_to_codes=None,
                 suppressions=None,
                 display_warnings=True):
        """Set up data structures.

        May extend - call:
        `super().__init__(db, conventions, manual_types_to_codes)`
        as the last statement in your function. Before that call, the
        subclass must set ``self.return_codes`` to the set of return code
        names defined in the registry.

        manual_types_to_codes is a dictionary of hard-coded
        "manual" return codes:
        the codes of the value are available for a command if-and-only-if
        the key type is passed as an input.

        forward_only_types_to_codes is additional entries to the above
        that should only be used in the "forward" direction
        (arg type implies return code)

        reverse_only_types_to_codes is additional entries to
        manual_types_to_codes that should only be used in the
        "reverse" direction
        (return code implies arg type)

        suppressions maps an element name to message substrings: any error
        or warning recorded while that element is the context, and whose
        text contains one of the substrings, is discarded.

        display_warnings controls whether check() prints recorded warnings.

        Raises RuntimeError if ``self.return_codes`` has not been set, or if
        any of the mappings above mentions a return code not in it.
        """
        self.fail = False
        self.entity = None
        # Error-context state, (re)set by set_error_context(). Initialized
        # here so record_error()/record_warning() work before any context
        # has been set.
        self.elem = None
        self.name = None
        self.entity_suppressions = None
        # Extension numbers seen so far; reset by check().
        self.ext_numbers = set()
        self.errors = DictOfStringSets()
        self.warnings = DictOfStringSets()
        self.db = entity_db
        self.reg = entity_db.registry
        self.handle_data = HandleData(self.reg)
        self.conventions = conventions
        self.display_warnings = display_warnings

        self.CONST_RE = re.compile(r"\bconst\b")
        self.ARRAY_RE = re.compile(r"\[[^]]+\]")

        if not manual_types_to_codes:
            manual_types_to_codes = {}
        if not reverse_only_types_to_codes:
            reverse_only_types_to_codes = {}
        if not forward_only_types_to_codes:
            forward_only_types_to_codes = {}

        # The bidirectional "manual" entries feed both directions.
        reverse_codes = DictOfStringSets(reverse_only_types_to_codes)
        forward_codes = DictOfStringSets(forward_only_types_to_codes)
        for k, v in manual_types_to_codes.items():
            forward_codes.add(k, v)
            reverse_codes.add(k, v)

        self.forward_only_manual_types_to_codes = forward_codes.get_dict()
        self.reverse_only_manual_types_to_codes = reverse_codes.get_dict()

        # The presence of some types as input to a function imply the
        # availability of some return codes.
        self.input_type_to_codes = compute_type_to_codes(
            self.handle_data,
            forward_codes,
            extra_op=self.add_extra_codes)

        # Some return codes require a type (or its child) in the input.
        self.codes_requiring_input_type = compute_codes_requiring_type(
            self.handle_data,
            reverse_codes
        )

        # Every return code this script mentions must exist in the registry.
        specified_codes = set(self.codes_requiring_input_type.keys())
        for codes in self.forward_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.reverse_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.input_type_to_codes.values():
            specified_codes.update(codes)

        # Declared for type checkers; the value is provided by the subclass.
        self.return_codes: Set[str]
        return_codes = getattr(self, 'return_codes', None)
        if return_codes is None:
            raise RuntimeError(
                "XMLChecker subclasses must set self.return_codes before "
                "calling XMLChecker.__init__()")
        unrecognized = specified_codes - return_codes
        if unrecognized:
            raise RuntimeError("Return code mentioned in script that isn't in the registry: " +
                               ', '.join(sorted(unrecognized)))

        self.referenced_input_types = ReferencedTypes(self.db, self.is_input)
        self.referenced_types = ReferencedTypes(self.db)
        if not suppressions:
            suppressions = {}
        self.suppressions = DictOfStringSets(suppressions)
        self.tags = _get_extension_tags(self.db.registry)

    def is_api_type(self, member_elem):
        """Return true if the member/parameter ElementTree passed is from this API.

        Decided by whether the API type prefix appears in the element's text.

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        return self.conventions.type_prefix in membertext

    def is_input(self, member_elem):
        """Return true if the member/parameter ElementTree passed is
        considered "input".

        Only API types (text containing the type prefix) can be input. Const
        values are input; non-const pointers and arrays are output; anything
        else is input.

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        if self.conventions.type_prefix not in membertext:
            return False

        # Const is always input.
        if self.CONST_RE.search(membertext):
            return True

        # Arrays and pointers that aren't const are always output.
        if "*" in membertext or self.ARRAY_RE.search(membertext):
            return False

        return True

    def strip_extension_tag(self, name):
        """Remove a single author tag from the end of a name, if any.

        A single underscore separating the tag from the rest of the name is
        removed too.

        Returns the stripped name and the tag, or the input and None if there was no tag.
        """
        for tag in self.tags:
            # An empty tag would otherwise "match" and strip the whole name.
            if tag and name.endswith(tag):
                stripped = name[:-len(tag)]
                # endswith() (unlike indexing [-1]) also copes with a name
                # that consisted solely of the tag.
                if stripped.endswith("_"):
                    stripped = stripped[:-1]
                return stripped, tag
        return name, None

    def add_extra_codes(self, types_to_codes):
        """Add any desired entries to the types-to-codes DictOfStringSets
        before performing "ancestor propagation".

        Passed to compute_type_to_codes as the extra_op.

        May override."""
        pass

    def should_skip_checking_codes(self, name):
        """Return True if more than the basic validation of return codes should
        be skipped for a command.

        The default ignores name and returns the conventions object's
        should_skip_checking_codes attribute.

        May override."""

        return self.conventions.should_skip_checking_codes

    def get_codes_for_command_and_type(self, cmd_name, type_name):
        """Return a set of return codes expected due to having
        an input argument of type type_name.

        The cmd_name is passed for use by extending methods.
        Note that you should not use cmd_name to add codes, just to
        filter them out. See get_required_codes_for_command() to do that.

        May extend."""
        return self.input_type_to_codes.get(type_name, set())

    def get_required_codes_for_command(self, cmd_name):
        """Return a set of return codes required due to having a particular name.

        May override."""
        return set()

    def get_forbidden_codes_for_command(self, cmd_name):
        """Return a set of return codes not permitted due to having a particular name.

        May override."""
        return set()

    def check(self):
        """Iterate through the registry, looking for consistency problems.

        Checks every command, every type that has a category, and every
        extension, then calls check_format(). Outputs error (and, if
        enabled, warning) messages at the end."""
        for name, info in self.reg.cmddict.items():
            self.set_error_context(entity=name, elem=info.elem)
            self.check_command(name, info)

        for name, info in self.reg.typedict.items():
            cat = info.elem.get('category')
            if not cat:
                # This is an external thing, skip it.
                continue
            self.set_error_context(entity=name, elem=info.elem)
            self.check_type(name, info, cat)

        # Reset the duplicate-number tracking used by check_extension().
        self.ext_numbers = set()
        for name, info in self.reg.extdict.items():
            self.set_error_context(entity=name, elem=info.elem)

            # Determine if this extension is supported by the API we're
            # testing, and pass that flag to check_extension.
            # For Vulkan, multiple APIs can be specified in the 'supported'
            # attribute.
            supported_apis = info.elem.get('supported', '').split(',')
            supported = self.conventions.xml_api_name in supported_apis
            self.check_extension(name, info, supported)

        self.check_format()

        self._print_messages()

    def check_param(self, param):
        """Check a member of a struct or a param of a function.

        Called from check_params.

        May extend."""
        param_name = getElemName(param)
        # Make sure there's something between the type and the name
        # Can't just look at the .tail of <type> for some reason,
        # so instead we look to see if anything's between
        # type's text and name's text in the itertext.
        # If there's no text between the tags, there will be no string
        # between those tags' text in itertext()
        text_parts = list(param.itertext())
        type_idx = text_parts.index(getElemType(param))
        name_idx = text_parts.index(param_name)
        if name_idx - type_idx == 1:
            self.record_error(
                "Space (or other delimiter text) missing between </type> and <name> for param/member named",
                param_name)

        # Check external sync entries
        externsyncs = ExternSyncEntry.parse_externsync_from_param(param)
        if externsyncs:
            for entry in externsyncs:
                if entry.entirely_extern_sync:
                    # 'true' makes any other list entries meaningless.
                    if len(externsyncs) > 1:
                        self.record_error("Comma-separated list in externsync attribute includes 'true' for",
                                          param_name)
                else:
                    # member name
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    if entry.member != param_name:
                        self.record_error("externsync attribute for", param_name,
                                          "refers to some other member/parameter:", entry.member)

    def check_params(self, params):
        """Check the members of a struct or params of a function.

        Runs check_param() on each, and verifies that every parameter named
        in a len= attribute exists among params.

        Called from check_type and check_command.

        May extend."""
        for param in params:
            self.check_param(param)

            # Check for parameters referenced by len= attribute
            lengths = LengthEntry.parse_len_from_param(param)
            if lengths:
                for entry in lengths:
                    if not entry.other_param_name:
                        continue
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    other_param = findNamedElem(params, entry.other_param_name)
                    if other_param is None:
                        self.record_error("References a non-existent parameter/member in the length of",
                                          getElemName(param), ":", entry.other_param_name)

    def check_referenced_type(self, desc, ref_name):
        """
        Record an error if a type mentioned somewhere doesn't exist.

        :param desc: Description of where this type reference was found,
                     for the error message.
        :param ref_name: The name of the referenced type. If false-ish (incl. None),
                         checking is skipped, so OK to pass the results of
                         info.elem.get() directly
        """
        if ref_name:
            entity = self.db.findEntity(ref_name)
            if not entity:
                self.record_error("Unknown type named in", desc, ":",
                                  ref_name)

    def check_type(self, name, info, category):
        """Check a type's XML data for consistency.

        Structs: name prefix, members, structure-type member value, and the
        'structextends'/'parentstruct' references. Bitmasks: name contains
        'Flags'. Handles: the 'parent' reference.

        Called from check.

        May extend."""
        if category == 'struct':
            if not name.startswith(self.conventions.type_prefix):
                self.record_error("Name does not start with",
                                  self.conventions.type_prefix)
            members = info.elem.findall('member')
            self.check_params(members)

            # Check the structure type member, if present.
            type_member = findNamedElem(
                members, self.conventions.structtype_member_name)
            if type_member is not None:
                val = type_member.get('values')
                if val:
                    expected = self.conventions.generate_structure_type_from_name(
                        name)
                    if val != expected:
                        self.record_error("Type has incorrect type-member value: expected",
                                          expected, "got", val)

            # Check structextends attribute, if present.
            # For Vulkan, this may be a comma-separated list of multiple types
            for extended in info.elem.get("structextends", '').split(','):
                self.check_referenced_type("'structextends' attribute", extended)

            # Check parentstruct attribute, if present.
            self.check_referenced_type("'parentstruct' attribute", info.elem.get("parentstruct"))

        elif category == "bitmask":
            if 'Flags' not in name:
                self.record_error("Name of bitmask doesn't include 'Flags'")
        elif category == "handle":
            # Check parent attribute, if present.
            self.check_referenced_type("'parent' attribute", info.elem.get("parent"))

    def check_extension(self, name, info, supported):
        """Check an extension's XML data for consistency.

        The default only verifies that non-zero extension numbers are unique.

        Called from check.

        May extend."""

        # Verify that each extension has a unique number
        extension_number = info.elem.get('number')
        if extension_number is not None and extension_number != '0':
            if extension_number in self.ext_numbers:
                self.record_error('Duplicate extension number ' + extension_number)
            else:
                self.ext_numbers.add(extension_number)

    def check_format(self):
        """Hook for format checks, run once after all commands, types and
        extensions have been checked.

        Called from check. Does nothing by default.

        May extend."""
        pass

    def check_command(self, name, info):
        """Check a command's XML data for consistency.

        Checks the params, then (if the command lists any return codes)
        looks for duplicated or overlapping success/error codes and runs the
        basic and, unless skipped, in-depth return code checks.

        Called from check.

        May extend."""
        elem = info.elem

        self.check_params(elem.findall('param'))

        # Some minimal return code checking
        errorcodes = elem.get("errorcodes")
        errorcodes = errorcodes.split(",") if errorcodes else []

        successcodes = elem.get("successcodes")
        successcodes = successcodes.split(",") if successcodes else []

        if not successcodes and not errorcodes:
            # Early out if no return codes.
            return

        # Create a set for each group of codes, and check that
        # they aren't duplicated within or between groups.
        errorcodes_set = set(errorcodes)
        if len(errorcodes) != len(errorcodes_set):
            self.record_error("Contains a duplicate in errorcodes")

        successcodes_set = set(successcodes)
        if len(successcodes) != len(successcodes_set):
            self.record_error("Contains a duplicate in successcodes")

        if not successcodes_set.isdisjoint(errorcodes_set):
            self.record_error("Has errorcodes and successcodes that overlap")

        self.check_command_return_codes_basic(
            name, info, successcodes_set, errorcodes_set)

        # Continue to further return code checking if not "complicated"
        if not self.should_skip_checking_codes(name):
            codes_set = successcodes_set.union(errorcodes_set)
            self.check_command_return_codes(
                name, info, successcodes_set, errorcodes_set, codes_set)

    def check_command_return_codes_basic(self, name, info,
                                         successcodes, errorcodes):
        """Check a command's return codes for consistency.

        Called from check_command on every command.

        May extend."""

        # Check that all error codes include _ERROR_,
        #  and that no success codes do.
        for code in errorcodes:
            if "_ERROR_" not in code:
                self.record_error(
                    code, "in errorcodes but doesn't contain _ERROR_")

        for code in successcodes:
            if "_ERROR_" in code:
                self.record_error(code, "in successcodes but contain _ERROR_")

    def check_command_return_codes(self, name, type_info,
                                   successcodes, errorcodes,
                                   codes):
        """Check a command's return codes in-depth for consistency.

        Called from check_command, only if
        `self.should_skip_checking_codes(name)` is False.

        May extend."""
        referenced_input = self.referenced_input_types[name]
        referenced_types = self.referenced_types[name]
        error_prefix = self.conventions.api_prefix + "ERROR"

        # Codes are sorted in messages so that output (and suppression
        # matching) does not depend on set iteration order.
        bad_success = {x for x in successcodes if x.startswith(error_prefix)}
        if bad_success:
            self.record_error("Found error code(s)",
                              ",".join(sorted(bad_success)),
                              "listed in the successcodes attributes")

        bad_errors = {x for x in errorcodes if not x.startswith(error_prefix)}
        if bad_errors:
            self.record_error("Found success code(s)",
                              ",".join(sorted(bad_errors)),
                              "listed in the errorcodes attributes")

        # Check that we have all the codes we expect, based on input types.
        for referenced_type in referenced_input:
            required_codes = self.get_codes_for_command_and_type(
                name, referenced_type)
            missing_codes = required_codes - codes
            if missing_codes:
                path = self.referenced_input_types.shortest_path(
                    name, referenced_type)
                path_str = " -> ".join(path)
                self.record_error("Missing expected return code(s)",
                                  ",".join(sorted(missing_codes)),
                                  "implied because of input of type",
                                  referenced_type,
                                  "found via path",
                                  path_str)

        # Check that we have all the codes we expect based on command name.
        missing_codes = self.get_required_codes_for_command(name) - codes
        if missing_codes:
            self.record_error("Missing expected return code(s)",
                              ",".join(sorted(missing_codes)),
                              "implied because of the name of this command")

        # Check that we don't have any codes forbidden based on command name.
        forbidden = self.get_forbidden_codes_for_command(name).intersection(codes)
        if forbidden:
            self.record_error("Got return code(s)",
                              ", ".join(sorted(forbidden)),
                              "that were forbidden due to the name of this command")

        # Check that, for each code returned by this command that we can
        # associate with a type, we have some type that can provide it.
        # e.g. can't have INSTANCE_LOST without an Instance
        # (or child of Instance).
        for code in codes:

            required_types = self.codes_requiring_input_type.get(code)
            if not required_types:
                # This code doesn't have a known requirement
                continue

            # TODO: do we look at referenced_types or referenced_input here?
            # the latter is stricter
            if not referenced_types.intersection(required_types):
                self.record_error("Unexpected return code", code,
                                  "- none of these types:",
                                  required_types,
                                  "found in the set of referenced types",
                                  referenced_types)

    ###
    # Utility properties/methods
    ###

    def _print_messages(self):
        """Print recorded errors (and warnings, if display_warnings is set), grouped by entity.

        Entities and messages are sorted so output is stable between runs.
        """
        entities = set(self.errors.keys())
        if self.display_warnings:
            entities.update(self.warnings.keys())
        if not entities:
            # Nothing will be displayed, so don't print the header either.
            return

        print('xml_consistency/consistency_tools error and warning messages follow.')

        # key=str: the entity may be None if a message was recorded before
        # any error context was set.
        for entity in sorted(entities, key=str):
            messages = self.errors.get(entity)
            if messages:
                print(f'\nError messages for {entity}')
                for m in sorted(messages):
                    print('ERROR:', m)

            if self.display_warnings:
                messages = self.warnings.get(entity)
                if messages:
                    print(f'\nWarning messages for {entity}')
                    for m in sorted(messages):
                        print('WARNING:', m)

    def set_error_context(self, entity=None, elem=None):
        """Set the entity and/or element for future record_error calls.

        Also selects the suppressions registered for the element's name."""
        self.entity = entity
        self.elem = elem
        self.name = getElemName(elem)
        self.entity_suppressions = self.suppressions.get(self.name)

    def record_error(self, *args, **kwargs):
        """Record failure and an error message for the current context.

        The args are str()-ed and space-joined; 'filename'/'elem' keyword
        arguments are forwarded to _prepend_sourceline_to_message."""
        message = " ".join(str(x) for x in args)

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.fail = True
        self.errors.add(self.entity, message)

    def record_warning(self, *args, **kwargs):
        """Record a warning message for the current context.

        Same argument handling as record_error, but does not set self.fail."""
        message = " ".join(str(x) for x in args)

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.warnings.add(self.entity, message)

    def _is_message_suppressed(self, message):
        """Return True if the given message, for this entity, should be suppressed.

        A message is suppressed if it contains any of the current context's
        suppression substrings."""
        if not self.entity_suppressions:
            return False
        return any(suppress in message for suppress in self.entity_suppressions)

    def _prepend_sourceline_to_message(self, message, **kwargs):
        """Prepend a file and/or line reference to the message, if possible.

        If filename is given as a keyword argument, it is used on its own.

        If filename is not given, this will attempt to retrieve the filename and line from an XML element.
        If 'elem' is given as a keyword argument and is not None, it is used to find the line.
        If 'elem' is given as None, no XML elements are looked at.
        If 'elem' is not supplied, the error context element is used.

        If using XML, the filename, if available, is retrieved from the Registry class.
        If using XML and python-lxml is installed, the source line is retrieved from whatever element is chosen."""
        fn = kwargs.get('filename')
        sourceline = None

        if fn is None:
            elem = kwargs.get('elem', self.elem)
            if elem is not None:
                # Only lxml elements carry a sourceline attribute.
                sourceline = getattr(elem, 'sourceline', None)
                reg_filename = getattr(self.reg, 'filename', None)
                if reg_filename:
                    fn = reg_filename

        if fn is None and sourceline is None:
            return message

        if fn is None:
            return "Line {}: {}".format(sourceline, message)

        if sourceline is None:
            return "{}: {}".format(fn, message)

        return "{}:{}: {}".format(fn, sourceline, message)
611
612
class HandleParents(RecursiveMemoize):
    """Map each handle type name to the list of all its ancestor handle names.

    Ancestors come from the (possibly comma-separated) 'parent' attribute of
    each handle's XML element, followed recursively.
    """

    def __init__(self, handle_types):
        """Initialize from a dict of handle type name to type info (with an .elem)."""
        self.handle_types = handle_types

        def compute(handle_type):
            """Return handle_type's ancestors: immediate parents first, then
            their ancestors, with repeats removed (first occurrence kept)."""
            immediate_parent = self.handle_types[handle_type].elem.get(
                'parent')

            if immediate_parent is None:
                # No parents, no need to recurse
                return []

            # Support multiple (alternate) parents
            immediate_parents = immediate_parent.split(',')

            # Recurse, combine, and return
            all_parents = immediate_parents[:]
            for parent in immediate_parents:
                all_parents.extend(self[parent])
            # Alternate parents may share ancestors; drop the repeats while
            # preserving first-seen order.
            return list(dict.fromkeys(all_parents))

        super().__init__(compute, handle_types.keys())
635
636
637def _always_true(x):
638    return True
639
640
class ReferencedTypes(RecursiveMemoize):
    """Recursively collect the types referenced by a struct or function.

    A member/param contributes its <type> only if the predicate accepts it;
    child types reported by the entity DB are always included. Every name
    visited is recorded in a networkx DiGraph so shortest_path() can show
    how one name was reached from another.
    """

    def __init__(self, db, predicate=None):
        """Initialize.

        db is an EntityDB object; predicate, if given, is called with each
        member/param element and decides whether its type counts. When
        omitted (or false-ish) every member is accepted."""
        self.db = db
        self.predicate = predicate or _always_true

        # type name -> set of directly referenced type names (memo)
        self._directly_referenced = {}
        self.graph = nx.DiGraph()

        def compute(type_name):
            """Return all types reachable from type_name that satisfy the predicate.

            Invoked through the base class's [] operator."""
            direct = self.directly_referenced(type_name)
            if not direct:
                return direct

            reachable = set(direct)
            for ref in direct:
                nested = self[ref]
                # With permit_cycles=True, None marks a name whose
                # computation is still in progress (a cycle): nothing to add.
                if nested is not None:
                    reachable.update(nested)
            return reachable

        # Initialize base class
        super().__init__(compute, permit_cycles=True)

    def shortest_path(self, source, target):
        """Return the shortest list of names leading from source to target."""
        # Looking up source fills the graph with everything reachable from it.
        self[source]

        return shortest_path(self.graph, source=source, target=target)

    def directly_referenced(self, type_name):
        """Return the types referenced directly by type_name that satisfy the predicate,
        plus its child types.

        The result is memoized, and the graph gains edges from type_name to
        each returned name."""
        if type_name in self._directly_referenced:
            return self._directly_referenced[type_name]

        direct = set()
        members = self.db.getMemberElems(type_name)
        if members:
            for member in members:
                type_elem = member.find("type")
                if type_elem is not None and self.predicate(member):
                    direct.add(type_elem.text)
        self._directly_referenced[type_name] = direct

        children = self.db.childTypes(type_name)
        if children:
            direct.update(children)

        # Update graph
        self.graph.add_node(type_name)
        self.graph.add_edges_from((type_name, ref) for ref in direct)

        return direct
708
709
class HandleData:
    """Data about all the handle types available in an API specification.

    Each property is computed on first access and cached afterwards.
    """

    def __init__(self, registry):
        """Wrap a registry whose ``typedict`` maps type names to type info."""
        self.reg = registry
        self._handle_types = None
        self._ancestors = None
        self._descendants = None

    @property
    def handle_types(self):
        """Return a dictionary of handle type names to type info."""
        # Test against None rather than truthiness so that an empty result
        # is cached too instead of being recomputed on every access.
        if self._handle_types is None:
            self._handle_types = {
                type_name: type_info
                for type_name, type_info in self.reg.typedict.items()
                if type_info.elem.get('category') == 'handle'
            }
        return self._handle_types

    @property
    def ancestors_dict(self):
        """Return a dictionary of handle type names to collections of ancestor names."""
        if self._ancestors is None:
            self._ancestors = HandleParents(self.handle_types).get_dict()
        return self._ancestors

    @property
    def descendants_dict(self):
        """Return a dictionary of handle type names to sets of descendants."""
        if self._descendants is None:
            handle_parents = self.ancestors_dict

            # Invert the ancestor map in a single pass: every handle is a
            # descendant of each of its ancestors. Only names that are keys
            # of the ancestor map get an entry.
            descendants = {h: set() for h in handle_parents}
            for child, ancestors in handle_parents.items():
                for ancestor in ancestors:
                    if ancestor in descendants:
                        descendants[ancestor].add(child)
            self._descendants = descendants
        return self._descendants
756
757
def compute_type_to_codes(handle_data, types_to_codes, extra_op=None):
    """Compute a DictOfStringSets of input type to required return codes.

    - handle_data is a HandleData instance.
    - types_to_codes is a dictionary of type names to strings or string
      collections of return codes.
    - extra_op, if any, is called after populating the output from the input
      dictionary, but before propagation of parent codes to child types.
      extra_op is called with the in-progress DictOfStringSets.

    After extra_op, every handle type additionally receives the codes of
    all of its ancestor handle types.

    Returns a DictOfStringSets of input type name to set of required return
    code names.
    """
    # Initialize with the supplied "manual" codes
    types_to_codes = DictOfStringSets(types_to_codes)

    # Dynamically generate more codes, if desired
    if extra_op:
        extra_op(types_to_codes)

    # Final post-processing: any handle can result in its parent handle's
    # codes too. All additions are computed from the pre-propagation codes
    # before any of them are applied.
    handle_ancestors = handle_data.ancestors_dict

    extra_handle_codes = {}
    for handle_type, ancestors in handle_ancestors.items():
        # The sets of return codes corresponding to each ancestor type.
        ancestors_codes = [types_to_codes.get(ancestor, set())
                           for ancestor in ancestors]
        extra_handle_codes[handle_type] = set().union(*ancestors_codes)

    for handle_type, extras in extra_handle_codes.items():
        types_to_codes.add(handle_type, extras)

    return types_to_codes
795
796
def compute_codes_requiring_type(handle_data, types_to_codes, registry=None):
    """Compute a DictOfStringSets of return codes to a set of input types able
    to provide the ability to generate that code.

    handle_data is a HandleData instance.
    types_to_codes is a dictionary of input types to associated return codes
    (same format as for input to compute_type_to_codes, may use same dict).
    This will invert that relationship, and also permit any "child handles"
    to satisfy a requirement for a parent in producing a code.
    registry is unused; it is accepted only for backward compatibility.

    Returns a DictOfStringSets of return code name to the set of parameter
    types that would allow that return code.
    """
    # Use DictOfStringSets to normalize the input into a dict with values
    # that are sets of strings
    in_dict = DictOfStringSets(types_to_codes)

    handle_descendants = handle_data.descendants_dict

    out = DictOfStringSets()
    for in_type, code_set in in_dict.items():
        # Descendant handles can stand in for in_type (non-handles have none).
        descendants = handle_descendants.get(in_type)
        for code in code_set:
            out.add(code, in_type)
            if descendants:
                out.add(code, descendants)

    return out
825