• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3 -i
2#
3# Copyright (c) 2019 Collabora, Ltd.
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Author(s):    Ryan Pavlik <ryan.pavlik@collabora.com>
8"""Provides utilities to write a script to verify XML registry consistency."""
9
10import re
11
12import networkx as nx
13
14from .algo import RecursiveMemoize
15from .attributes import ExternSyncEntry, LengthEntry
16from .data_structures import DictOfStringSets
17from .util import findNamedElem, getElemName
18
19
class XMLChecker:
    """Base class for a script that checks an XML API registry for consistency.

    Call check() to run every check and print the collected messages.
    Methods documented as "May override" or "May extend" are the hooks a
    concrete checker customizes.
    """

    def __init__(self, entity_db, conventions, manual_types_to_codes=None,
                 forward_only_types_to_codes=None,
                 reverse_only_types_to_codes=None,
                 suppressions=None):
        """Set up data structures.

        May extend - call:
        `super().__init__(db, conventions, manual_types_to_codes)`
        as the last statement in your function.

        NOTE(review): this method reads `self.return_codes`, which this class
        never assigns; a subclass is presumably expected to set it (the set
        of all return codes in the registry) before calling this - confirm
        against the subclasses.

        manual_types_to_codes is a dictionary of hard-coded
        "manual" return codes:
        the codes of the value are available for a command if-and-only-if
        the key type is passed as an input.

        forward_only_types_to_codes is additional entries to the above
        that should only be used in the "forward" direction
        (arg type implies return code)

        reverse_only_types_to_codes is additional entries to
        manual_types_to_codes that should only be used in the
        "reverse" direction
        (return code implies arg type)

        suppressions is a dictionary keyed by element name. A message recorded
        while that element is the error context is dropped if it contains any
        of the strings associated with that key.

        Raises RuntimeError if a return code mentioned in the dictionaries
        above (or added by add_extra_codes) is not in self.return_codes.
        """
        self.fail = False
        self.entity = None
        # The rest of the error context is normally filled in by
        # set_error_context(); initialize it so that record_error() and
        # record_warning() do not raise AttributeError if called earlier.
        self.elem = None
        self.entity_suppressions = None
        self.errors = DictOfStringSets()
        self.warnings = DictOfStringSets()
        self.db = entity_db
        self.reg = entity_db.registry
        self.handle_data = HandleData(self.reg)
        self.conventions = conventions

        self.CONST_RE = re.compile(r"\bconst\b")
        self.ARRAY_RE = re.compile(r"\[[^]]+\]")

        # Init memoized properties (not read anywhere in this class; kept so
        # code that expects the attribute still finds it).
        self._handle_data = None

        if not manual_types_to_codes:
            manual_types_to_codes = {}
        if not reverse_only_types_to_codes:
            reverse_only_types_to_codes = {}
        if not forward_only_types_to_codes:
            forward_only_types_to_codes = {}

        # The "manual" entries apply in both directions; the *_only entries
        # apply in just one.
        reverse_codes = DictOfStringSets(reverse_only_types_to_codes)
        forward_codes = DictOfStringSets(forward_only_types_to_codes)
        for k, v in manual_types_to_codes.items():
            forward_codes.add(k, v)
            reverse_codes.add(k, v)

        self.forward_only_manual_types_to_codes = forward_codes.get_dict()
        self.reverse_only_manual_types_to_codes = reverse_codes.get_dict()

        # The presence of some types as input to a function imply the
        # availability of some return codes.
        self.input_type_to_codes = compute_type_to_codes(
            self.handle_data,
            forward_codes,
            extra_op=self.add_extra_codes)

        # Some return codes require a type (or its child) in the input.
        self.codes_requiring_input_type = compute_codes_requiring_type(
            self.handle_data,
            reverse_codes
        )

        # Every return code named anywhere in the configuration must exist in
        # the registry; a typo here would otherwise silently disable a check.
        specified_codes = set(self.codes_requiring_input_type.keys())
        for codes in self.forward_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.reverse_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.input_type_to_codes.values():
            specified_codes.update(codes)

        unrecognized = specified_codes - self.return_codes
        if unrecognized:
            raise RuntimeError("Return code mentioned in script that isn't in the registry: " +
                               ', '.join(unrecognized))

        self.referenced_input_types = ReferencedTypes(self.db, self.is_input)
        self.referenced_api_types = ReferencedTypes(self.db, self.is_api_type)
        if not suppressions:
            suppressions = {}
        self.suppressions = DictOfStringSets(suppressions)

    def is_api_type(self, member_elem):
        """Return true if the member/parameter ElementTree passed is from this API.

        The default implementation checks whether the element's full text
        contains the API type prefix from the conventions object.

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        return self.conventions.type_prefix in membertext

    def is_input(self, member_elem):
        """Return true if the member/parameter ElementTree passed is
        considered "input".

        Only members whose text contains the API type prefix are considered.
        A member marked `const` is input; otherwise a pointer or array
        member is treated as output, and anything else is input.

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        if self.conventions.type_prefix not in membertext:
            return False

        # Const is always input.
        if self.CONST_RE.search(membertext):
            return True

        # Arrays and pointers that aren't const are always output.
        if "*" in membertext or self.ARRAY_RE.search(membertext):
            return False

        return True

    def add_extra_codes(self, types_to_codes):
        """Add any desired entries to the types-to-codes DictOfStringSets
        before performing "ancestor propagation".

        Passed to compute_type_to_codes as the extra_op.

        May override."""
        pass

    def should_skip_checking_codes(self, name):
        """Return True if more than the basic validation of return codes should
        be skipped for a command.

        The default ignores `name` and returns
        `self.conventions.should_skip_checking_codes` (read as an attribute).

        May override."""

        return self.conventions.should_skip_checking_codes

    def get_codes_for_command_and_type(self, cmd_name, type_name):
        """Return a set of error codes expected due to having
        an input argument of type type_name.

        The cmd_name is passed for use by extending methods.

        May extend."""
        return self.input_type_to_codes.get(type_name, set())

    def check(self):
        """Iterate through the registry, looking for consistency problems.

        Checks every command, every type that has a `category` attribute,
        and every extension, then calls check_format(). Finally prints all
        recorded errors and warnings grouped by entity, in sorted order so
        the report is stable between runs."""
        # Iterate through commands, looking for consistency problems.
        for name, info in self.reg.cmddict.items():
            self.set_error_context(entity=name, elem=info.elem)

            self.check_command(name, info)

        for name, info in self.reg.typedict.items():
            cat = info.elem.get('category')
            if not cat:
                # This is an external thing, skip it.
                continue
            self.set_error_context(entity=name, elem=info.elem)

            self.check_type(name, info, cat)

        # check_extension is called for all extensions, even 'disabled'
        # ones, but some checks may be skipped depending on extension
        # status.
        for name, info in self.reg.extdict.items():
            self.set_error_context(entity=name, elem=info.elem)
            self.check_extension(name, info)

        self.check_format()

        entities_with_messages = set(
            self.errors.keys()).union(self.warnings.keys())
        if entities_with_messages:
            print('xml_consistency/consistency_tools error and warning messages follow.')

        # Sets iterate in an arbitrary order; sort for reproducible output.
        # key=str tolerates a None entity (a message recorded outside any
        # error context) mixed with string entity names.
        for entity in sorted(entities_with_messages, key=str):
            print()
            print('-------------------')
            print('Messages for', entity)
            print()
            messages = self.errors.get(entity)
            if messages:
                for m in sorted(messages):
                    print('Error:', m)

            messages = self.warnings.get(entity)
            if messages:
                for m in sorted(messages):
                    print('Warning:', m)

    def check_param(self, param):
        """Check a member of a struct or a param of a function.

        Validates the externsync attribute: 'true' must not be combined
        with other entries, and any named entry must refer to this
        member/parameter itself.

        Called from check_params.

        May extend."""
        param_name = getElemName(param)
        externsyncs = ExternSyncEntry.parse_externsync_from_param(param)
        if externsyncs:
            for entry in externsyncs:
                if entry.entirely_extern_sync:
                    if len(externsyncs) > 1:
                        self.record_error("Comma-separated list in externsync attribute includes 'true' for",
                                          param_name)
                else:
                    # member name
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    if entry.member != param_name:
                        self.record_error("externsync attribute for", param_name,
                                          "refers to some other member/parameter:", entry.member)

    def check_params(self, params):
        """Check the members of a struct or params of a function.

        Runs check_param on each, and verifies that every member/param
        named in a len= attribute exists in the same list.

        Called from check_type and check_command.

        May extend."""
        for param in params:
            self.check_param(param)

            # Check for parameters referenced by len= attribute
            lengths = LengthEntry.parse_len_from_param(param)
            if lengths:
                for entry in lengths:
                    if not entry.other_param_name:
                        continue
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    other_param = findNamedElem(params, entry.other_param_name)
                    if other_param is None:
                        self.record_error("References a non-existent parameter/member in the length of",
                                          getElemName(param), ":", entry.other_param_name)

    def check_type(self, name, info, category):
        """Check a type's XML data for consistency.

        For structs: checks the name prefix, the members, and the value of
        the structure-type member (if any). For bitmasks: checks that the
        name contains 'Flags'.

        Called from check.

        May extend."""
        if category == 'struct':
            if not name.startswith(self.conventions.type_prefix):
                self.record_error("Name does not start with",
                                  self.conventions.type_prefix)
            members = info.elem.findall('member')
            self.check_params(members)

            # Check the structure type member, if present.
            type_member = findNamedElem(
                members, self.conventions.structtype_member_name)
            if type_member is not None:
                val = type_member.get('values')
                if val:
                    expected = self.conventions.generate_structure_type_from_name(
                        name)
                    if val != expected:
                        self.record_error("Type has incorrect type-member value: expected",
                                          expected, "got", val)

        elif category == "bitmask":
            if 'Flags' not in name:
                self.record_error("Name of bitmask doesn't include 'Flags'")

    def check_extension(self, name, info):
        """Check an extension's XML data for consistency.

        Called from check. The default does nothing.

        May extend."""
        pass

    def check_format(self):
        """Hook for checks that run once, after all commands, types and
        extensions have been checked.

        Called from check. The default does nothing.

        May extend."""
        pass

    def check_command(self, name, info):
        """Check a command's XML data for consistency.

        Checks the params, then the errorcodes/successcodes attributes:
        no duplicates within or between the two lists, the basic checks of
        check_command_return_codes_basic, and (unless
        should_skip_checking_codes says otherwise) the in-depth checks of
        check_command_return_codes.

        Called from check.

        May extend."""
        elem = info.elem

        self.check_params(elem.findall('param'))

        # Some minimal return code checking
        errorcodes = elem.get("errorcodes")
        errorcodes = errorcodes.split(",") if errorcodes else []

        successcodes = elem.get("successcodes")
        successcodes = successcodes.split(",") if successcodes else []

        if not successcodes and not errorcodes:
            # Early out if no return codes.
            return

        # Create a set for each group of codes, and check that
        # they aren't duplicated within or between groups.
        errorcodes_set = set(errorcodes)
        if len(errorcodes) != len(errorcodes_set):
            self.record_error("Contains a duplicate in errorcodes")

        successcodes_set = set(successcodes)
        if len(successcodes) != len(successcodes_set):
            self.record_error("Contains a duplicate in successcodes")

        if not successcodes_set.isdisjoint(errorcodes_set):
            self.record_error("Has errorcodes and successcodes that overlap")

        self.check_command_return_codes_basic(
            name, info, successcodes_set, errorcodes_set)

        # Continue to further return code checking if not "complicated"
        if not self.should_skip_checking_codes(name):
            codes_set = successcodes_set.union(errorcodes_set)
            self.check_command_return_codes(
                name, info, successcodes_set, errorcodes_set, codes_set)

    def check_command_return_codes_basic(self, name, info,
                                         successcodes, errorcodes):
        """Check a command's return codes for consistency.

        Called from check_command on every command.

        May extend."""

        # Check that all error codes include _ERROR_,
        #  and that no success codes do.
        for code in errorcodes:
            if "_ERROR_" not in code:
                self.record_error(
                    code, "in errorcodes but doesn't contain _ERROR_")

        for code in successcodes:
            if "_ERROR_" in code:
                self.record_error(code, "in successcodes but contain _ERROR_")

    def check_command_return_codes(self, name, type_info,
                                   successcodes, errorcodes,
                                   codes):
        """Check a command's return codes in-depth for consistency.

        Reports codes implied by an input type but missing from the command,
        and codes the command returns although none of the types able to
        produce them is referenced.

        Called from check_command, only if
        `self.should_skip_checking_codes(name)` is False.

        May extend."""
        referenced_input = self.referenced_input_types[name]
        referenced_types = self.referenced_api_types[name]

        # Check that we have all the codes we expect, based on input types.
        for referenced_type in referenced_input:
            required_codes = self.get_codes_for_command_and_type(
                name, referenced_type)
            missing_codes = required_codes - codes
            if missing_codes:
                path = self.referenced_input_types.shortest_path(
                    name, referenced_type)
                path_str = " -> ".join(path)
                # Sorted so the message text is stable between runs.
                self.record_error("Missing expected return code(s)",
                                  ",".join(sorted(missing_codes)),
                                  "implied because of input of type",
                                  referenced_type,
                                  "found via path",
                                  path_str)

        # Check that, for each code returned by this command that we can
        # associate with a type, we have some type that can provide it.
        # e.g. can't have INSTANCE_LOST without an Instance
        # (or child of Instance).
        for code in codes:

            required_types = self.codes_requiring_input_type.get(code)
            if not required_types:
                # This code doesn't have a known requirement
                continue

            # TODO: do we look at referenced_types or referenced_input here?
            # the latter is stricter
            if not referenced_types.intersection(required_types):
                self.record_error("Unexpected return code", code,
                                  "- none of these types:",
                                  required_types,
                                  "found in the set of referenced types",
                                  referenced_types)

    ###
    # Utility properties/methods
    ###

    def set_error_context(self, entity=None, elem=None):
        """Set the entity and/or element for future record_error calls.

        Also selects the suppressions registered under the element's name."""
        self.entity = entity
        self.elem = elem
        self.name = getElemName(elem)
        self.entity_suppressions = self.suppressions.get(self.name)

    def record_error(self, *args, **kwargs):
        """Record failure and an error message for the current context.

        Positional args are converted with str() and joined with spaces.
        Keyword args are forwarded to _prepend_sourceline_to_message.
        Suppressed messages are dropped and do not set self.fail."""
        message = " ".join(str(x) for x in args)

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.fail = True
        self.errors.add(self.entity, message)

    def record_warning(self, *args, **kwargs):
        """Record a warning message for the current context.

        Same message formatting and suppression as record_error, but does
        not set self.fail."""
        message = " ".join(str(x) for x in args)

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.warnings.add(self.entity, message)

    def _is_message_suppressed(self, message):
        """Return True if the given message, for this entity, should be suppressed.

        A message is suppressed when any suppression string for the current
        context is a substring of it."""
        if not self.entity_suppressions:
            return False
        return any(suppress in message for suppress in self.entity_suppressions)

    def _prepend_sourceline_to_message(self, message, **kwargs):
        """Prepend a file and/or line reference to the message, if possible.

        If filename is given as a keyword argument, it is used on its own.

        If filename is not given, this will attempt to retrieve the filename and line from an XML element.
        If 'elem' is given as a keyword argument and is not None, it is used to find the line.
        If 'elem' is given as None, no XML elements are looked at.
        If 'elem' is not supplied, the error context element is used.

        If using XML, the filename, if available, is retrieved from the Registry class.
        If using XML and python-lxml is installed, the source line is retrieved from whatever element is chosen."""
        fn = kwargs.get('filename')
        sourceline = None

        if fn is None:
            elem = kwargs.get('elem', self.elem)
            if elem is not None:
                # Only lxml elements carry a sourceline attribute.
                sourceline = getattr(elem, 'sourceline', None)
                if self.reg.filename:
                    fn = self.reg.filename

        if fn is None and sourceline is None:
            return message

        if fn is None:
            return "Line {}: {}".format(sourceline, message)

        if sourceline is None:
            return "{}: {}".format(fn, message)

        return "{}:{}: {}".format(fn, sourceline, message)
492
493
class HandleParents(RecursiveMemoize):
    """Memoized map from each handle type name to all of its ancestor handle types.

    Ancestors are read from the comma-separated `parent` attribute of each
    handle's XML element and followed recursively.
    """

    def __init__(self, handle_types):
        """Initialize.

        handle_types is a dictionary of handle type name to type info (an
        object whose `elem` is the type's XML element), e.g.
        HandleData.handle_types.
        """
        # Must exist before the base-class constructor runs: it is handed
        # compute() and the handle names, and compute() reads this.
        self.handle_types = handle_types

        def compute(handle_type):
            """Return a list of every ancestor of handle_type.

            Immediate parents come first, followed by their ancestors;
            each ancestor appears only once."""
            immediate_parent = self.handle_types[handle_type].elem.get(
                'parent')

            if immediate_parent is None:
                # No parents, no need to recurse
                return []

            # Support multiple (alternate) parents. Strip whitespace and drop
            # empty entries so "A, B" or "" do not produce bogus lookups.
            immediate_parents = [p.strip() for p in immediate_parent.split(',')
                                 if p.strip()]

            # Recurse, combine, and return
            all_parents = list(immediate_parents)
            for parent in immediate_parents:
                all_parents.extend(self[parent])

            # A diamond-shaped hierarchy reaches the same ancestor along more
            # than one path; keep only the first occurrence (order-preserving).
            return list(dict.fromkeys(all_parents))

        super().__init__(compute, handle_types.keys())
516
517
518def _always_true(x):
519    return True
520
521
class ReferencedTypes(RecursiveMemoize):
    """Recursively collect the types referenced by a struct or function.

    Only members/parameters accepted by an optional predicate are followed.
    A directed graph of the direct references is built along the way, so a
    reference chain can be reported with shortest_path().
    """

    def __init__(self, db, predicate=None):
        """Initialize.

        db is an EntityDB object; its getMemberElems(name) supplies the
        member/param elements of a struct or function.
        predicate, if given, is called with each member/param element and
        decides whether that element's type is followed. When omitted (or
        falsy), every element is accepted.
        """
        self.db = db
        # Default predicate is "anything goes".
        self.predicate = predicate or _always_true

        # Cache: type name -> set of type names it references directly.
        self._directly_referenced = {}
        # Edge (a, b) means a directly references b.
        self.graph = nx.DiGraph()

        def compute(type_name):
            """Return every type reachable from type_name that satisfies the predicate.

            Invoked by the base class's [] operator."""
            direct = self.directly_referenced(type_name)
            if not direct:
                return direct

            everything = set(direct)
            for child in direct:
                nested = self[child]
                # A None result means following this child would close a
                # cycle; there is nothing to merge in that case.
                if nested is not None:
                    everything.update(nested)
            return everything

        super().__init__(compute, permit_cycles=True)

    def shortest_path(self, source, target):
        """Return the shortest chain of type/function names from source to target."""
        # Looking up source populates the reference graph.
        _ = self[source]

        return nx.algorithms.shortest_path(self.graph, source=source, target=target)

    def directly_referenced(self, type_name):
        """Return the types type_name references directly whose member/param satisfies the predicate.

        Results are cached, and each newly computed entry also adds its edges
        to self.graph."""
        cache = self._directly_referenced
        if type_name not in cache:
            found = set()
            for member in self.db.getMemberElems(type_name) or ():
                type_elem = member.find("type")
                if type_elem is not None and self.predicate(member):
                    found.add(type_elem.text)
            cache[type_name] = found

            # Update graph
            self.graph.add_node(type_name)
            self.graph.add_edges_from((type_name, target) for target in found)

        return cache[type_name]
587
588
class HandleData:
    """Data about all the handle types available in an API specification.

    Each property is computed on first access and cached afterwards.
    """

    def __init__(self, registry):
        """Initialize from a registry object exposing a `typedict` mapping."""
        self.reg = registry
        # Lazily computed caches; None means "not computed yet". (An empty
        # result is a valid, cached value.)
        self._handle_types = None
        self._ancestors = None
        self._descendants = None

    @property
    def handle_types(self):
        """Return a dictionary of handle type names to type info.

        Handles are the registry types whose category is 'handle'."""
        if self._handle_types is None:
            # First time requested - compute it.
            self._handle_types = {
                type_name: type_info
                for type_name, type_info in self.reg.typedict.items()
                if type_info.elem.get('category') == 'handle'
            }
        return self._handle_types

    @property
    def ancestors_dict(self):
        """Return a dictionary of handle type names to their ancestor handle
        type names, as computed by HandleParents."""
        if self._ancestors is None:
            # First time requested - compute it.
            self._ancestors = HandleParents(self.handle_types).get_dict()
        return self._ancestors

    @property
    def descendants_dict(self):
        """Return a dictionary of handle type names to sets of descendants.

        Built by inverting ancestors_dict: h is a descendant of a when a is
        listed among h's ancestors. Every handle gets an entry, possibly an
        empty set."""
        if self._descendants is None:
            # First time requested - compute it.
            handle_parents = self.ancestors_dict
            descendants = {handle: set() for handle in handle_parents}
            for handle, ancestors in handle_parents.items():
                for ancestor in ancestors:
                    if ancestor in descendants:
                        descendants[ancestor].add(handle)
            self._descendants = descendants
        return self._descendants
635
636
def compute_type_to_codes(handle_data, types_to_codes, extra_op=None):
    """Compute a DictOfStringSets of input type to required return codes.

    - handle_data is a HandleData instance.
    - types_to_codes is a dictionary of type names to strings or string
      collections of return codes.
    - extra_op, if any, is called after populating the output from the input
      dictionary, but before propagation of ancestor codes to child types.
      extra_op is called with the in-progress DictOfStringSets.

    Returns a DictOfStringSets of input type name to set of required return
    code names. Each handle type additionally receives the codes of all of
    its ancestor handle types.
    """
    # Initialize with the supplied "manual" codes
    types_to_codes = DictOfStringSets(types_to_codes)

    # Dynamically generate more codes, if desired
    if extra_op:
        extra_op(types_to_codes)

    # Final post-processing

    # Any handle can result in its ancestor handles' codes too.
    # Collect every handle's additions first and apply them afterwards, so
    # each handle sees only the codes of its own ancestors, not codes just
    # propagated to a sibling.
    extra_handle_codes = {}
    for handle_type, ancestors in handle_data.ancestors_dict.items():
        codes = set()
        for ancestor in ancestors:
            # update() mutates in place; the previous codes.union(...) built
            # a new set and discarded it, so nothing was ever propagated.
            codes.update(types_to_codes.get(ancestor, set()))
        extra_handle_codes[handle_type] = codes

    for handle_type, extras in extra_handle_codes.items():
        types_to_codes.add(handle_type, extras)

    return types_to_codes
678
679
def compute_codes_requiring_type(handle_data, types_to_codes, registry=None):
    """Invert a type-to-codes mapping into return code -> enabling input types.

    handle_data is a HandleData instance.
    types_to_codes maps input type names to associated return codes (same
    format as the input to compute_type_to_codes; the same dict may be
    reused). Besides inverting the relationship, every descendant handle of
    an input type is also allowed to satisfy that type's requirement.
    registry is accepted but not used.

    Returns a DictOfStringSets of return code name to the set of parameter
    types that would allow that return code.
    """
    # Normalize the input into a dict whose values are sets of strings.
    normalized = DictOfStringSets(types_to_codes)

    descendants_by_handle = handle_data.descendants_dict

    result = DictOfStringSets()
    for in_type, code_set in normalized.items():
        children = descendants_by_handle.get(in_type)
        for code in code_set:
            result.add(code, in_type)
            if children:
                result.add(code, children)

    return result
708