• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3 -i
2#
3# Copyright (c) 2019 Collabora, Ltd.
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Author(s):    Ryan Pavlik <ryan.pavlik@collabora.com>
8"""Provides utilities to write a script to verify XML registry consistency."""
9
10import re
11
12import networkx as nx
13
14from .algo import RecursiveMemoize
15from .attributes import ExternSyncEntry, LengthEntry
16from .data_structures import DictOfStringSets
17from .util import findNamedElem, getElemName
18
19
class XMLChecker:
    """Base class for a script that checks an XML API registry for consistency.

    Construct a subclass instance, then call `check()`. Problems are
    collected per entity via `record_error()` / `record_warning()` and
    printed at the end of `check()`. `self.fail` becomes True once any
    error (not merely a warning) has been recorded.

    Methods documented as "May override" or "May extend" are the intended
    customization points for subclasses.
    """

    def __init__(self, entity_db,  conventions, manual_types_to_codes=None,
                 forward_only_types_to_codes=None,
                 reverse_only_types_to_codes=None,
                 suppressions=None):
        """Set up data structures.

        May extend - call:
        `super().__init__(db, conventions, manual_types_to_codes)`
        as the last statement in your function.

        NOTE(review): this method reads `self.return_codes` (the set of
        return code names known to the registry), which this class never
        assigns. It presumably must be set by the subclass before calling
        this method - confirm against subclasses.

        entity_db is the entity database; its `registry` attribute is the
        registry that gets checked.

        conventions supplies API-specific naming conventions (type prefix,
        structure-type member name, and so on).

        manual_types_to_codes is a dictionary of hard-coded
        "manual" return codes:
        the codes of the value are available for a command if-and-only-if
        the key type is passed as an input.

        forward_only_types_to_codes is additional entries to the above
        that should only be used in the "forward" direction
        (arg type implies return code)

        reverse_only_types_to_codes is additional entries to
        manual_types_to_codes that should only be used in the
        "reverse" direction
        (return code implies arg type)

        suppressions maps an element name to strings. An error or warning
        recorded while that element is the error context is dropped if its
        message contains any of those strings.

        Raises RuntimeError if a return code mentioned in the mappings above
        is not in `self.return_codes`.
        """
        self.fail = False
        # Error-reporting context, normally updated by set_error_context().
        # Initialized here so that record_error()/record_warning() also work
        # before any context is set, e.g. if add_extra_codes() reports
        # something while this constructor is still running.
        self.entity = None
        self.elem = None
        self.name = None
        self.entity_suppressions = None
        self.suppressions = DictOfStringSets(suppressions or {})

        self.errors = DictOfStringSets()
        self.warnings = DictOfStringSets()
        self.db = entity_db
        self.reg = entity_db.registry
        self.handle_data = HandleData(self.reg)
        self.conventions = conventions

        self.CONST_RE = re.compile(r"\bconst\b")
        self.ARRAY_RE = re.compile(r"\[[^]]+\]")

        # Init memoized properties
        # (not read anywhere in this class; kept for compatibility).
        self._handle_data = None

        if not manual_types_to_codes:
            manual_types_to_codes = {}
        if not reverse_only_types_to_codes:
            reverse_only_types_to_codes = {}
        if not forward_only_types_to_codes:
            forward_only_types_to_codes = {}

        # The "manual" mapping applies in both directions, so merge it into
        # each of the direction-specific mappings.
        reverse_codes = DictOfStringSets(reverse_only_types_to_codes)
        forward_codes = DictOfStringSets(forward_only_types_to_codes)
        for k, v in manual_types_to_codes.items():
            forward_codes.add(k, v)
            reverse_codes.add(k, v)

        self.forward_only_manual_types_to_codes = forward_codes.get_dict()
        self.reverse_only_manual_types_to_codes = reverse_codes.get_dict()

        # The presence of some types as input to a function imply the
        # availability of some return codes.
        self.input_type_to_codes = compute_type_to_codes(
            self.handle_data,
            forward_codes,
            extra_op=self.add_extra_codes)

        # Some return codes require a type (or its child) in the input.
        self.codes_requiring_input_type = compute_codes_requiring_type(
            self.handle_data,
            reverse_codes
        )

        # Every return code named by the script must exist in the registry.
        specified_codes = set(self.codes_requiring_input_type.keys())
        for codes in self.forward_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.reverse_only_manual_types_to_codes.values():
            specified_codes.update(codes)
        for codes in self.input_type_to_codes.values():
            specified_codes.update(codes)

        unrecognized = specified_codes - self.return_codes
        if unrecognized:
            # Sorted so the message is stable from run to run.
            raise RuntimeError("Return code mentioned in script that isn't in the registry: " +
                               ', '.join(sorted(unrecognized)))

        self.referenced_input_types = ReferencedTypes(self.db, self.is_input)
        self.referenced_api_types = ReferencedTypes(self.db, self.is_api_type)

    def is_api_type(self, member_elem):
        """Return true if the member/parameter ElementTree passed is from this API.

        The test is whether the element's full text contains
        `conventions.type_prefix`.

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        return self.conventions.type_prefix in membertext

    def is_input(self, member_elem):
        """Return true if the member/parameter ElementTree passed is
        considered "input".

        Only API types (text contains `conventions.type_prefix`) qualify.
        Anything `const` is input; otherwise pointers and arrays are output
        and everything else is input.

        May override or extend."""
        membertext = "".join(member_elem.itertext())

        if self.conventions.type_prefix not in membertext:
            return False

        # Const is always input.
        if self.CONST_RE.search(membertext):
            return True

        # Arrays and pointers that aren't const are always output.
        if "*" in membertext or self.ARRAY_RE.search(membertext):
            return False

        return True

    def add_extra_codes(self, types_to_codes):
        """Add any desired entries to the types-to-codes DictOfStringSets
        before performing "ancestor propagation".

        Passed to compute_type_to_codes as the extra_op.

        May override."""
        pass

    def should_skip_checking_codes(self, name):
        """Return True if more than the basic validation of return codes should
        be skipped for a command.

        This default ignores `name` and returns
        `self.conventions.should_skip_checking_codes`.

        May override."""

        return self.conventions.should_skip_checking_codes

    def get_codes_for_command_and_type(self, cmd_name, type_name):
        """Return a set of error codes expected due to having
        an input argument of type type_name.

        The cmd_name is passed for use by extending methods.

        May extend."""
        return self.input_type_to_codes.get(type_name, set())

    def check(self):
        """Iterate through the registry, looking for consistency problems.

        Checks every command, every type with a `category`, and every
        extension whose `supported` attribute equals
        `conventions.xml_api_name`. Outputs error messages at the end,
        grouped by entity and sorted so the report is stable across runs."""
        # Iterate through commands, looking for consistency problems.
        for name, info in self.reg.cmddict.items():
            self.set_error_context(entity=name, elem=info.elem)

            self.check_command(name, info)

        for name, info in self.reg.typedict.items():
            cat = info.elem.get('category')
            if not cat:
                # This is an external thing, skip it.
                continue
            self.set_error_context(entity=name, elem=info.elem)

            self.check_type(name, info, cat)

        for name, info in self.reg.extdict.items():
            if info.elem.get('supported') != self.conventions.xml_api_name:
                # Skip unsupported extensions
                continue
            self.set_error_context(entity=name, elem=info.elem)
            self.check_extension(name, info)

        entities_with_messages = set(
            self.errors.keys()).union(self.warnings.keys())
        if entities_with_messages:
            print('xml_consistency/consistency_tools error and warning messages follow.')

        # key=str tolerates a None entity (messages recorded before any
        # error context was set) alongside string entity names.
        for entity in sorted(entities_with_messages, key=str):
            print()
            print('-------------------')
            print('Messages for', entity)
            print()
            for m in sorted(self.errors.get(entity) or ()):
                print('Error:', m)

            for m in sorted(self.warnings.get(entity) or ()):
                print('Warning:', m)

    def check_param(self, param):
        """Check a member of a struct or a param of a function.

        Records an error if an externsync attribute combines 'true' with
        other entries, or names a member/parameter other than this one.

        Called from check_params.

        May extend."""
        param_name = getElemName(param)
        externsyncs = ExternSyncEntry.parse_externsync_from_param(param)
        if externsyncs:
            for entry in externsyncs:
                if entry.entirely_extern_sync:
                    if len(externsyncs) > 1:
                        self.record_error("Comma-separated list in externsync attribute includes 'true' for",
                                          param_name)
                else:
                    # member name
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    if entry.member != param_name:
                        self.record_error("externsync attribute for", param_name,
                                          "refers to some other member/parameter:", entry.member)

    def check_params(self, params):
        """Check the members of a struct or params of a function.

        Runs check_param on each element, and records an error for any
        len= attribute that references a name not present in params.

        Called from check_type and check_command.

        May extend."""
        for param in params:
            self.check_param(param)

            # Check for parameters referenced by len= attribute
            lengths = LengthEntry.parse_len_from_param(param)
            if lengths:
                for entry in lengths:
                    if not entry.other_param_name:
                        continue
                    # TODO only looking at the superficial feature here,
                    # not entry.param_ref_parts
                    other_param = findNamedElem(params, entry.other_param_name)
                    if other_param is None:
                        self.record_error("References a non-existent parameter/member in the length of",
                                          getElemName(param), ":", entry.other_param_name)

    def check_type(self, name, info, category):
        """Check a type's XML data for consistency.

        Structs: the name must start with the API type prefix, the members
        are checked, and the structure-type member's `values` (if any) must
        equal `conventions.generate_structure_type_from_name(name)`.
        Bitmasks: the name must contain 'Flags'.

        Called from check.

        May extend."""
        if category == 'struct':
            if not name.startswith(self.conventions.type_prefix):
                self.record_error("Name does not start with",
                                  self.conventions.type_prefix)
            members = info.elem.findall('member')
            self.check_params(members)

            # Check the structure type member, if present.
            type_member = findNamedElem(
                members, self.conventions.structtype_member_name)
            if type_member is not None:
                val = type_member.get('values')
                if val:
                    expected = self.conventions.generate_structure_type_from_name(
                        name)
                    if val != expected:
                        self.record_error("Type has incorrect type-member value: expected",
                                          expected, "got", val)

        elif category == "bitmask":
            if 'Flags' not in name:
                self.record_error("Name of bitmask doesn't include 'Flags'")

    def check_extension(self, name, info):
        """Check an extension's XML data for consistency.

        Called from check.

        May extend."""
        pass

    def check_command(self, name, info):
        """Check a command's XML data for consistency.

        Checks the params, then parses the comma-separated `errorcodes` and
        `successcodes` attributes, reporting duplicates and overlap. Runs
        the basic return-code checks always, and the in-depth ones unless
        `should_skip_checking_codes(name)` is true.

        Called from check.

        May extend."""
        elem = info.elem

        self.check_params(elem.findall('param'))

        # Some minimal return code checking
        errorcodes = elem.get("errorcodes")
        errorcodes = errorcodes.split(",") if errorcodes else []

        successcodes = elem.get("successcodes")
        successcodes = successcodes.split(",") if successcodes else []

        if not successcodes and not errorcodes:
            # Early out if no return codes.
            return

        # Create a set for each group of codes, and check that
        # they aren't duplicated within or between groups.
        errorcodes_set = set(errorcodes)
        if len(errorcodes) != len(errorcodes_set):
            self.record_error("Contains a duplicate in errorcodes")

        successcodes_set = set(successcodes)
        if len(successcodes) != len(successcodes_set):
            self.record_error("Contains a duplicate in successcodes")

        if not successcodes_set.isdisjoint(errorcodes_set):
            self.record_error("Has errorcodes and successcodes that overlap")

        self.check_command_return_codes_basic(
            name, info, successcodes_set, errorcodes_set)

        # Continue to further return code checking if not "complicated"
        if not self.should_skip_checking_codes(name):
            codes_set = successcodes_set.union(errorcodes_set)
            self.check_command_return_codes(
                name, info, successcodes_set, errorcodes_set, codes_set)

    def check_command_return_codes_basic(self, name, info,
                                         successcodes, errorcodes):
        """Check a command's return codes for consistency.

        Called from check_command on every command.

        May extend."""

        # Check that all error codes include _ERROR_,
        #  and that no success codes do.
        for code in errorcodes:
            if "_ERROR_" not in code:
                self.record_error(
                    code, "in errorcodes but doesn't contain _ERROR_")

        for code in successcodes:
            if "_ERROR_" in code:
                self.record_error(code, "in successcodes but contain _ERROR_")

    def check_command_return_codes(self, name, type_info,
                                   successcodes, errorcodes,
                                   codes):
        """Check a command's return codes in-depth for consistency.

        Reports codes implied by an input type but missing from `codes`,
        and codes that require a type none of whose members appear among
        the command's referenced API types.

        Called from check_command, only if
        `self.should_skip_checking_codes(name)` is False.

        May extend."""
        referenced_input = self.referenced_input_types[name]
        referenced_types = self.referenced_api_types[name]

        # Check that we have all the codes we expect, based on input types.
        for referenced_type in referenced_input:
            required_codes = self.get_codes_for_command_and_type(
                name, referenced_type)
            missing_codes = required_codes - codes
            if missing_codes:
                path = self.referenced_input_types.shortest_path(
                    name, referenced_type)
                path_str = " -> ".join(path)
                self.record_error("Missing expected return code(s)",
                                  ",".join(missing_codes),
                                  "implied because of input of type",
                                  referenced_type,
                                  "found via path",
                                  path_str)

        # Check that, for each code returned by this command that we can
        # associate with a type, we have some type that can provide it.
        # e.g. can't have INSTANCE_LOST without an Instance
        # (or child of Instance).
        for code in codes:

            required_types = self.codes_requiring_input_type.get(code)
            if not required_types:
                # This code doesn't have a known requirement
                continue

            # TODO: do we look at referenced_types or referenced_input here?
            # the latter is stricter
            if not referenced_types.intersection(required_types):
                self.record_error("Unexpected return code", code,
                                  "- none of these types:",
                                  required_types,
                                  "found in the set of referenced types",
                                  referenced_types)

    ###
    # Utility properties/methods
    ###

    def set_error_context(self, entity=None, elem=None):
        """Set the entity and/or element for future record_error calls.

        Also selects the suppressions registered under the element's name."""
        self.entity = entity
        self.elem = elem
        self.name = getElemName(elem)
        self.entity_suppressions = self.suppressions.get(self.name)

    def record_error(self, *args, **kwargs):
        """Record failure and an error message for the current context.

        Positional args are str()-ed and joined with spaces. Keyword args
        (`filename`, `elem`) are forwarded to _prepend_sourceline_to_message.
        A suppressed message is dropped and does not set `self.fail`."""
        message = " ".join((str(x) for x in args))

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.fail = True
        self.errors.add(self.entity, message)

    def record_warning(self, *args, **kwargs):
        """Record a warning message for the current context.

        Same argument handling as record_error, but never sets `self.fail`."""
        message = " ".join((str(x) for x in args))

        if self._is_message_suppressed(message):
            return

        message = self._prepend_sourceline_to_message(message, **kwargs)
        self.warnings.add(self.entity, message)

    def _is_message_suppressed(self, message):
        """Return True if the given message, for this entity, should be suppressed.

        A message is suppressed when it contains any of the current
        context's suppression strings as a substring."""
        if not self.entity_suppressions:
            return False
        return any(suppress in message for suppress in self.entity_suppressions)

    def _prepend_sourceline_to_message(self, message, **kwargs):
        """Prepend a file and/or line reference to the message, if possible.

        If filename is given as a keyword argument, it is used on its own.

        If filename is not given, this will attempt to retrieve the filename and line from an XML element.
        If 'elem' is given as a keyword argument and is not None, it is used to find the line.
        If 'elem' is given as None, no XML elements are looked at.
        If 'elem' is not supplied, the error context element is used.

        If using XML, the filename, if available, is retrieved from the Registry class.
        If using XML and python-lxml is installed, the source line is retrieved from whatever element is chosen."""
        fn = kwargs.get('filename')
        sourceline = None

        if fn is None:
            elem = kwargs.get('elem', self.elem)
            if elem is not None:
                # Only lxml elements carry `sourceline`.
                sourceline = getattr(elem, 'sourceline', None)
                # Tolerate registries that have no `filename` attribute.
                reg_filename = getattr(self.reg, 'filename', None)
                if reg_filename:
                    fn = reg_filename

        if fn is None and sourceline is None:
            return message

        if fn is None:
            return "Line {}: {}".format(sourceline, message)

        if sourceline is None:
            return "{}: {}".format(fn, message)

        return "{}:{}: {}".format(fn, sourceline, message)
476            return "Line {}: {}".format(sourceline, message)
477
478        if sourceline is None:
479            return "{}: {}".format(fn, message)
480
481        return "{}:{}: {}".format(fn, sourceline, message)
482
483
484class HandleParents(RecursiveMemoize):
485    def __init__(self, handle_types):
486        self.handle_types = handle_types
487
488        def compute(handle_type):
489            immediate_parent = self.handle_types[handle_type].elem.get(
490                'parent')
491
492            if immediate_parent is None:
493                # No parents, no need to recurse
494                return []
495
496            # Support multiple (alternate) parents
497            immediate_parents = immediate_parent.split(',')
498
499            # Recurse, combine, and return
500            all_parents = immediate_parents[:]
501            for parent in immediate_parents:
502                all_parents.extend(self[parent])
503            return all_parents
504
505        super().__init__(compute, handle_types.keys())
506
507
508def _always_true(x):
509    return True
510
511
512class ReferencedTypes(RecursiveMemoize):
513    """Find all types(optionally matching a predicate) that are referenced
514    by a struct or function, recursively."""
515
516    def __init__(self, db, predicate=None):
517        """Initialize.
518
519        Provide an EntityDB object and a predicate function."""
520        self.db = db
521
522        self.predicate = predicate
523        if not self.predicate:
524            # Default predicate is "anything goes"
525            self.predicate = _always_true
526
527        self._directly_referenced = {}
528        self.graph = nx.DiGraph()
529
530        def compute(type_name):
531            """Compute and return all types referenced by type_name, recursively, that satisfy the predicate.
532
533            Called by the [] operator in the base class."""
534            types = self.directly_referenced(type_name)
535            if not types:
536                return types
537
538            all_types = set()
539            all_types.update(types)
540            for t in types:
541                referenced = self[t]
542                if referenced is not None:
543                    # If not leading to a cycle
544                    all_types.update(referenced)
545            return all_types
546
547        # Initialize base class
548        super().__init__(compute, permit_cycles=True)
549
550    def shortest_path(self, source, target):
551        """Get the shortest path between one type/function name and another."""
552        # Trigger computation
553        _ = self[source]
554
555        return nx.algorithms.shortest_path(self.graph, source=source, target=target)
556
557    def directly_referenced(self, type_name):
558        """Get all types referenced directly by type_name that satisfy the predicate.
559
560        Memoizes its results."""
561        if type_name not in self._directly_referenced:
562            members = self.db.getMemberElems(type_name)
563            if members:
564                types = ((member, member.find("type")) for member in members)
565                self._directly_referenced[type_name] = set(type_elem.text for (member, type_elem) in types
566                                                           if type_elem is not None and self.predicate(member))
567
568            else:
569                self._directly_referenced[type_name] = set()
570
571            # Update graph
572            self.graph.add_node(type_name)
573            self.graph.add_edges_from((type_name, t)
574                                      for t in self._directly_referenced[type_name])
575
576        return self._directly_referenced[type_name]
577
578
579class HandleData:
580    """Data about all the handle types available in an API specification."""
581
582    def __init__(self, registry):
583        self.reg = registry
584        self._handle_types = None
585        self._ancestors = None
586        self._descendants = None
587
588    @property
589    def handle_types(self):
590        """Return a dictionary of handle type names to type info."""
591        if not self._handle_types:
592            # First time requested - compute it.
593            self._handle_types = {
594                type_name: type_info
595                for type_name, type_info in self.reg.typedict.items()
596                if type_info.elem.get('category') == 'handle'
597            }
598        return self._handle_types
599
600    @property
601    def ancestors_dict(self):
602        """Return a dictionary of handle type names to sets of ancestors."""
603        if not self._ancestors:
604            # First time requested - compute it.
605            self._ancestors = HandleParents(self.handle_types).get_dict()
606        return self._ancestors
607
608    @property
609    def descendants_dict(self):
610        """Return a dictionary of handle type names to sets of descendants."""
611        if not self._descendants:
612            # First time requested - compute it.
613
614            handle_parents = self.ancestors_dict
615
616            def get_descendants(handle):
617                return set(h for h in handle_parents.keys()
618                           if handle in handle_parents[h])
619
620            self._descendants = {
621                h: get_descendants(h)
622                for h in handle_parents.keys()
623            }
624        return self._descendants
625
626
627def compute_type_to_codes(handle_data, types_to_codes, extra_op=None):
628    """Compute a DictOfStringSets of input type to required return codes.
629
630    - handle_data is a HandleData instance.
631    - d is a dictionary of type names to strings or string collections of
632      return codes.
633    - extra_op, if any, is called after populating the output from the input
634      dictionary, but before propagation of parent codes to child types.
635      extra_op is called with the in-progress DictOfStringSets.
636
637    Returns a DictOfStringSets of input type name to set of required return
638    code names.
639    """
640    # Initialize with the supplied "manual" codes
641    types_to_codes = DictOfStringSets(types_to_codes)
642
643    # Dynamically generate more codes, if desired
644    if extra_op:
645        extra_op(types_to_codes)
646
647    # Final post-processing
648
649    # Any handle can result in its parent handle's codes too.
650
651    handle_ancestors = handle_data.ancestors_dict
652
653    extra_handle_codes = {}
654    for handle_type, ancestors in handle_ancestors.items():
655        codes = set()
656        # The sets of return codes corresponding to each ancestor type.
657        ancestors_codes = (types_to_codes.get(ancestor, set())
658                           for ancestor in ancestors)
659        codes.union(*ancestors_codes)
660        # for parent_codes in ancestors_codes:
661        #     codes.update(parent_codes)
662        extra_handle_codes[handle_type] = codes
663
664    for handle_type, extras in extra_handle_codes.items():
665        types_to_codes.add(handle_type, extras)
666
667    return types_to_codes
668
669
670def compute_codes_requiring_type(handle_data, types_to_codes, registry=None):
671    """Compute a DictOfStringSets of return codes to a set of input types able
672    to provide the ability to generate that code.
673
674    handle_data is a HandleData instance.
675    d is a dictionary of input types to associated return codes(same format
676    as for input to compute_type_to_codes, may use same dict).
677    This will invert that relationship, and also permit any "child handles"
678    to satisfy a requirement for a parent in producing a code.
679
680    Returns a DictOfStringSets of return code name to the set of parameter
681    types that would allow that return code.
682    """
683    # Use DictOfStringSets to normalize the input into a dict with values
684    # that are sets of strings
685    in_dict = DictOfStringSets(types_to_codes)
686
687    handle_descendants = handle_data.descendants_dict
688
689    out = DictOfStringSets()
690    for in_type, code_set in in_dict.items():
691        descendants = handle_descendants.get(in_type)
692        for code in code_set:
693            out.add(code, in_type)
694            if descendants:
695                out.add(code, descendants)
696
697    return out
698