• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""Provides EntityDatabase, a class that keeps track of spec-defined entities and associated macros."""
2
3# Copyright (c) 2018-2019 Collabora, Ltd.
4#
5# SPDX-License-Identifier: Apache-2.0
6#
7# Author(s):    Ryan Pavlik <ryan.pavlik@collabora.com>
8
9from abc import ABC, abstractmethod
10
11from .shared import (CATEGORIES_WITH_VALIDITY, EXTENSION_CATEGORY,
12                     NON_EXISTENT_MACROS, EntityData)
13from .util import getElemName
14
15
def _entityToDict(data):
    """Return a plain dict holding the debug-relevant fields of an EntityData-like object.

    Used by EntityDatabase.getEntityJson(). The entity name is the key of the
    enclosing dict there, so it is not repeated here; the ETree element is
    omitted because it is not JSON-serializable.
    """
    exported_fields = ('macro', 'filename', 'category', 'directory')
    return {field: getattr(data, field) for field in exported_fields}
23
24
class EntityDatabase(ABC):
    """Parsed and processed information from the registry XML.

    Must be subclassed for each specific API: subclasses implement the
    abstract methods below and may override or extend the documented hooks.
    All population happens in __init__; afterwards the object is meant to be
    queried through its accessor methods.
    """

    ###
    # Methods that must be implemented in subclasses.
    ###
    @abstractmethod
    def makeRegistry(self):
        """Return a Registry object that has already had loadFile() and parseTree() called.

        Called only once during construction.
        """
        raise NotImplementedError

    @abstractmethod
    def getNamePrefix(self):
        """Return the (two-letter) prefix of all entity names for this API.

        Called only once during construction.
        """
        raise NotImplementedError

    @abstractmethod
    def getPlatformRequires(self):
        """Return the 'requires' string associated with external/platform definitions.

        This is the string found in the requires attribute of the XML for entities that
        are externally defined in a platform include file, like the question marks in:

        <type requires="???" name="int8_t"/>

        In Vulkan, this is 'vk_platform'.

        Called only once during construction.
        """
        raise NotImplementedError

    ###
    # Methods that it is optional to **override**
    ###
    def getSystemTypes(self):
        """Return an enumerable of strings that name system types.

        System types use the macro `code`, and they do not generate API/validity includes.

        Called only once during construction.
        """
        return []

    def getGeneratedDirs(self):
        """Return a sequence of strings that are the subdirectories of generated API includes.

        Called only once during construction.
        """
        return ['basetypes',
                'defines',
                'enums',
                'flags',
                'funcpointers',
                'handles',
                'protos',
                'structs']

    def populateMacros(self):
        """Perform API-specific calls, if any, to self.addMacro() and self.addMacros().

        It is recommended to implement/override this and call
        self.addMacros(..., ..., [..., "flags"]),
        since the base implementation, in _basicPopulateMacros(),
        does not add any macros as pertaining to the category "flags".

        Called only once during construction, before the default macros are added,
        so macros registered here take precedence (addMacro ignores repeats).
        """
        pass

    def populateEntities(self):
        """Perform API-specific calls, if any, to self.addEntity().

        Called only once during construction, before the default entity population,
        so entities added here take precedence (addEntity ignores repeats).
        """
        pass

    def getEntitiesWithoutValidity(self):
        """Return an enumerable of entity names that do not generate validity includes.

        Called only once during construction; the default relies on
        self.mixed_case_name_prefix having already been set.
        """
        return [self.mixed_case_name_prefix +
                x for x in ['BaseInStructure', 'BaseOutStructure']]

    def getExclusionSet(self):
        """Return a set of "support=" attribute strings that should not be included in the database.

        Called only during construction.
        """
        return set(('disabled',))

    ###
    # Methods that it is optional to **extend**
    ###
    def handleType(self, name, info, requires):
        """Add entities, if appropriate, for an item in registry.typedict.

        Called at construction for every name, info in registry.typedict.items()
        not immediately skipped,
        to perform the correct associated addEntity() call, if applicable.
        The contents of the requires attribute, if any, is passed in requires.

        May be extended by API-specific code to handle some cases preferentially,
        then calling the super implementation to handle the rest.

        Raises RuntimeError for a type category this method does not know.
        """
        if requires == self.platform_requires:
            # Ah, no, don't skip this, it's just in the platform header file.
            # TODO are these code or basetype?
            self.addEntity(name, 'code', elem=info.elem, generates=False)
            return

        protect = info.elem.get('protect')
        if protect:
            self.addEntity(protect, 'dlink',
                           category='configdefines', generates=False)

        alias = info.elem.get('alias')
        if alias:
            self.addAlias(name, alias)

        cat = info.elem.get('category')
        if cat == 'struct':
            self.addEntity(name, 'slink', elem=info.elem)

        elif cat == 'union':
            # TODO: is this right?
            self.addEntity(name, 'slink', elem=info.elem)

        elif cat == 'enum':
            self.addEntity(
                name, 'elink', elem=info.elem)

        elif cat == 'handle':
            self.addEntity(name, 'slink', elem=info.elem,
                           category='handles')

        elif cat == 'bitmask':
            self.addEntity(
                name, 'tlink', elem=info.elem, category='flags')

        elif cat == 'basetype':
            self.addEntity(name, 'basetype',
                           elem=info.elem)

        elif cat == 'define':
            self.addEntity(name, 'dlink', elem=info.elem)

        elif cat == 'funcpointer':
            self.addEntity(name, 'tlink', elem=info.elem)

        elif cat == 'include':
            # skip
            return

        elif cat is None:
            self.addEntity(name, 'code', elem=info.elem, generates=False)

        else:
            raise RuntimeError('unrecognized category {}'.format(cat))

    def handleCommand(self, name, info):
        """Add entities, if appropriate, for an item in registry.cmddict.

        Called at construction for every name, info in registry.cmddict.items().
        Calls self.addEntity() accordingly.
        """
        self.addEntity(name, 'flink', elem=info.elem,
                       category='commands', directory='protos')

    def handleExtension(self, name, info):
        """Add entities, if appropriate, for an item in registry.extdict.

        Called at construction for every name, info in registry.extdict.items().
        Calls self.addEntity() accordingly.
        """
        if info.supported in self._supportExclusionSet:
            # Don't populate with disabled extensions.
            return

        # Only get the protect strings and name from extensions

        self.addEntity(name, None, category=EXTENSION_CATEGORY,
                       generates=False)
        protect = info.elem.get('protect')
        if protect:
            self.addEntity(protect, 'dlink',
                           category='configdefines', generates=False)

    def handleEnumValue(self, name, info):
        """Add entities, if appropriate, for an item in registry.enumdict.

        Called at construction for every name, info in registry.enumdict.items().
        Calls self.addEntity() accordingly.
        """
        self.addEntity(name, 'ename', elem=info.elem,
                       category='enumvalues', generates=False)

    ###
    # END of methods intended to be implemented, overridden, or extended in child classes!
    ###

    ###
    # Accessors
    ###
    def findMacroAndEntity(self, macro, entity):
        """Look up EntityData by macro and entity pair.

        Does **not** resolve aliases.
        """
        return self._byMacroAndEntity.get((macro, entity))

    def findEntity(self, entity):
        """Look up EntityData by entity name (case-sensitive).

        If there is no direct entry, the entity's aliases are tried and the entry
        of the first alias that has one is returned. Returns None if nothing matches.
        """
        result = self._byEntity.get(entity)
        if result:
            return result

        alias_set = self._aliasSetsByEntity.get(entity)
        if alias_set:
            for alias in alias_set:
                if alias in self._byEntity:
                    return self._byEntity[alias]

            # Every alias group is expected to contain at least one recorded entity.
            assert False, "Alias without main entry!"

        return None

    def findEntityCaseInsensitive(self, entity):
        """Look up entities by name (case-insensitive).

        Returns the list of EntityData whose names lower-case to the same string
        as entity, or None if there are none.

        Does **not** resolve aliases.
        """
        return self._byLowercaseEntity.get(entity.lower())

    def getMemberElems(self, commandOrStruct):
        """Given a command or struct name, retrieve the ETree elements for each member/param.

        Returns None if the entity is not found or has no associated element.
        Entities using the 'slink' macro are searched for <member> elements,
        everything else for <param> elements.
        """
        data = self.findEntity(commandOrStruct)

        if not data:
            return None
        if data.elem is None:
            return None
        if data.macro == 'slink':
            tag = 'member'
        else:
            tag = 'param'
        return data.elem.findall('.//{}'.format(tag))

    def getMemberNames(self, commandOrStruct):
        """Given a command or struct name, retrieve the names of each member/param.

        Returns an empty list if the entity is not found or doesn't have members/params.
        Members/params without a <name> child are skipped.
        """
        members = self.getMemberElems(commandOrStruct)
        if not members:
            return []
        names = []
        for member in members:
            name_tag = member.find('name')
            # Compare against None explicitly: an Element's truth value reflects
            # whether it has child elements, so a text-only <name> would be falsy.
            if name_tag is not None:
                names.append(name_tag.text)
        return names

    def getEntityJson(self):
        """Dump the internal entity dictionary to JSON for debugging."""
        import json
        d = {entity: _entityToDict(data)
             for entity, data in self._byEntity.items()}
        return json.dumps(d, sort_keys=True, indent=4)

    def entityHasValidity(self, entity):
        """Estimate if we expect to see a validity include for an entity name.

        Returns None if the entity name is not known (or a struct-like entity has no
        members), otherwise a boolean: True if a validity include is expected.

        Related to Generator.isStructAlwaysValid.
        """
        data = self.findEntity(entity)
        if not data:
            return None

        if entity in self.entities_without_validity:
            return False

        if data.category == 'protos':
            # All protos have validity
            return True

        if data.category not in CATEGORIES_WITH_VALIDITY:
            return False

        # Handle structs here.
        members = self.getMemberElems(entity)
        if not members:
            return None
        for member in members:
            member_name = getElemName(member)
            # Tolerate a member without a <type> child instead of raising AttributeError.
            type_elem = member.find('type')
            member_type = type_elem.text if type_elem is not None else None
            member_category = member.get('category')

            if member_name in ('next', 'type'):
                return True

            if member_type in ('void', 'char'):
                return True

            if member.get('noautovalidity'):
                # Not generating validity for this member, skip it
                continue

            if member.get('len'):
                # Array
                return True

            # Text following </type> (e.g. " *") marks the member as a pointer.
            typetail = type_elem.tail if type_elem is not None else None
            if typetail and '*' in typetail:
                # Pointer
                return True

            if member_category in ('handle', 'enum', 'bitmask'):
                return True

            if member_category in ('struct', 'union') \
                    and self.entityHasValidity(member_type):
                # struct or union member - recurse
                return True

        # Got this far - no validity needed
        return False

    def entityGenerates(self, entity_name):
        """Return True if the named entity generates include file(s)."""
        return entity_name in self._generating_entities

    @property
    def generating_entities(self):
        """Return a view of all generating entity names."""
        return self._generating_entities.keys()

    def shouldBeRecognized(self, macro, entity_name):
        """Determine, based on the macro and the name provided, if we should expect to recognize the entity.

        True if it is linked. Specific APIs may also provide additional cases where it is True.
        """
        return self.isLinkedMacro(macro)

    def likelyRecognizedEntity(self, entity_name):
        """Guess (based on name prefix alone) if an entity is likely to be recognized."""
        return entity_name.lower().startswith(self.name_prefix)

    def isLinkedMacro(self, macro):
        """Identify if a macro is considered a "linked" macro."""
        return macro in self._linkedMacros

    def isValidMacro(self, macro):
        """Identify if a macro is known and valid."""
        if macro not in self._categoriesByMacro:
            return False

        return macro not in NON_EXISTENT_MACROS

    def getCategoriesForMacro(self, macro):
        """Identify the categories associated with a (known, valid) macro.

        Returns None for an unknown macro.
        """
        return self._categoriesByMacro.get(macro)

    def areAliases(self, first_entity_name, second_entity_name):
        """Return true if the two entity names are equivalent (aliases of each other)."""
        alias_set = self._aliasSetsByEntity.get(first_entity_name)
        if not alias_set:
            # addAlias() registers every member of a group, so the relation is
            # symmetric: the second name's group (if any) cannot contain the first.
            # If this assert fails, we have goofed in addAlias.
            assert first_entity_name not in self._aliasSetsByEntity.get(
                second_entity_name, ())

            return False

        return second_entity_name in alias_set

    @property
    def macros(self):
        """Return the collection of all known entity-related markup macros."""
        return self._categoriesByMacro.keys()

    ###
    # Methods only used during initial setup/population of this data structure
    ###
    def addMacro(self, macro, categories, link=False):
        """Add a single markup macro to the collection of categories by macro.

        Also adds the macro to the set of linked macros if link=True.

        If a macro has already been supplied to a call, later calls for that macro have no effect.
        """
        if macro in self._categoriesByMacro:
            return
        self._categoriesByMacro[macro] = categories
        if link:
            self._linkedMacros.add(macro)

    def addMacros(self, letter, macroTypes, categories):
        """Add markup macros associated with a leading letter to the collection of categories by macro.

        Also, those macros created using 'link' in macroTypes will also be added to the set of linked macros.

        Basically automates a number of calls to addMacro().
        """
        for macroType in macroTypes:
            macro = letter + macroType
            self.addMacro(macro, categories, link=(macroType == 'link'))

    def addAlias(self, entityName, aliasName):
        """Record that entityName is an alias for aliasName.

        Aliases are kept in groups (sets) of mutually equivalent names; every name
        in a group maps to the same set object. If the two names already belong to
        different groups, those groups are merged into one.
        """
        # See if we already have something with either name.
        alias_set = self._aliasSetsByEntity.get(aliasName)
        other_alias_set = self._aliasSetsByEntity.get(entityName)

        if alias_set is not None and other_alias_set is not None \
                and alias_set is not other_alias_set:
            # Two distinct groups are being joined: fold the second into the first
            # and repoint every member of the second at the merged set.
            alias_set |= other_alias_set
            for name in other_alias_set:
                self._aliasSetsByEntity[name] = alias_set
            # Remove by identity: distinct groups are disjoint, but be explicit.
            self._aliasSets[:] = [s for s in self._aliasSets
                                  if s is not other_alias_set]

        if alias_set is None:
            # Try looking by the other name.
            alias_set = other_alias_set

        if alias_set is None:
            # Nope, this is a new set.
            alias_set = set()
            self._aliasSets.append(alias_set)

        # Add both names to the set
        alias_set.add(entityName)
        alias_set.add(aliasName)

        # Associate the set with each name
        self._aliasSetsByEntity[aliasName] = alias_set
        self._aliasSetsByEntity[entityName] = alias_set

    def addEntity(self, entityName, macro, category=None, elem=None,
                  generates=None, directory=None, filename=None):
        """Add an entity (command, structure type, enum, enum value, etc) in the database.

        If an entityName has already been supplied to a call, later calls for that entityName have no effect.

        Arguments:
        entityName -- the name of the entity.
        macro -- the macro (without the trailing colon) that should be used to refer to this entity.

        Optional keyword arguments:
        category -- If not manually specified, looked up based on the macro
                    (the first category registered for that macro).
        elem -- The ETree element associated with the entity in the registry XML.
        generates -- Indicates whether this entity generates api and validity include files.
                     Default depends on directory (or if not specified, category).
        directory -- The directory that include files (under api/ and validity/) are generated in.
                     If not specified (and generates is True), the default is the same as the category,
                     which is almost always correct.
        filename -- The relative filename (under api/ or validity/) where includes are generated for this.
                    This only matters if generates is True (default). If not specified and generates is True,
                    one will be generated based on directory and entityName.

        Raises RuntimeError if category is not given and cannot be looked up from the macro.
        """
        # Probably dealt with in handleType(), but just in case it wasn't.
        if elem is not None:
            alias = elem.get('alias')
            if alias:
                self.addAlias(entityName, alias)

        if entityName in self._byEntity:
            # skip if already recorded.
            return

        # Look up category based on the macro, if category isn't specified.
        if category is None:
            categories = self._categoriesByMacro.get(macro)
            if not categories:
                raise RuntimeError(
                    'cannot determine category for entity {}: macro {} has no known categories'.format(
                        entityName, macro))
            category = categories[0]

        if generates is None:
            potential_dir = directory or category
            generates = potential_dir in self._generated_dirs

        # If directory isn't specified and this entity generates,
        # the directory is the same as the category.
        if directory is None and generates:
            directory = category

        # Don't generate a filename if this entity doesn't generate includes.
        if filename is None and generates:
            filename = '{}/{}.txt'.format(directory, entityName)

        data = EntityData(
            entity=entityName,
            macro=macro,
            elem=elem,
            filename=filename,
            category=category,
            directory=directory
        )

        self._byEntity[entityName] = data
        # Several entities may share a lower-cased name, so each key maps to a list.
        self._byLowercaseEntity.setdefault(entityName.lower(), []).append(data)
        self._byMacroAndEntity[(macro, entityName)] = data
        if generates and filename is not None:
            self._generating_entities[entityName] = data

    def __init__(self):
        """Constructor: Do not extend or override.

        Changing the behavior of other parts of this logic should be done by
        implementing, extending, or overriding (as documented):

        - Implement makeRegistry()
        - Implement getNamePrefix()
        - Implement getPlatformRequires()
        - Override getSystemTypes()
        - Override getGeneratedDirs()
        - Override getEntitiesWithoutValidity()
        - Override getExclusionSet()
        - Override populateMacros()
        - Override populateEntities()
        - Extend handleType()
        - Extend handleCommand()
        - Extend handleExtension()
        - Extend handleEnumValue()
        """
        # Internal data that we don't want consumers of the class touching for fear of
        # breaking invariants
        self._byEntity = {}
        self._byLowercaseEntity = {}
        self._byMacroAndEntity = {}
        self._categoriesByMacro = {}
        self._linkedMacros = set()
        self._aliasSetsByEntity = {}
        self._aliasSets = []

        self._registry = None

        # Retrieve from subclass, if overridden, then store locally.
        self._supportExclusionSet = set(self.getExclusionSet())

        # Entities that get a generated/api/category/entity.txt file.
        self._generating_entities = {}

        # Name prefix members
        self.name_prefix = self.getNamePrefix().lower()
        self.mixed_case_name_prefix = self.name_prefix[:1].upper(
        ) + self.name_prefix[1:]
        # Regex string for the name prefix that is case-insensitive.
        self.case_insensitive_name_prefix_pattern = ''.join(
            ('[{}{}]'.format(c.upper(), c) for c in self.name_prefix))

        self.platform_requires = self.getPlatformRequires()

        self._generated_dirs = set(self.getGeneratedDirs())

        # Note: Default impl requires self.mixed_case_name_prefix
        self.entities_without_validity = set(self.getEntitiesWithoutValidity())

        # TODO: Where should flags actually go? Not mentioned in the style guide.
        # TODO: What about flag wildcards? There are a few such uses...

        # Optional override: API-specific macros (e.g. for flags). Runs first, so
        # its registrations win over the defaults (addMacro ignores repeats).
        self.populateMacros()

        # Now, do default macro population
        self._basicPopulateMacros()

        # Optional override: API-specific "not from the registry" (and not system type)
        # entities. Runs first, so they win over registry-derived ones.
        self.populateEntities()

        # Now, do default entity population
        self._basicPopulateEntities(self.registry)

    ###
    # Methods only used internally during initial setup/population of this data structure
    ###
    @property
    def registry(self):
        """Return a Registry, creating it with makeRegistry() on first access."""
        if self._registry is None:
            self._registry = self.makeRegistry()
        return self._registry

    def _basicPopulateMacros(self):
        """Contains calls to self.addMacro() and self.addMacros().

        If you need to change any of these, do so in your override of populateMacros(),
        which will be called first.
        """
        self.addMacro('basetype', ['basetypes'])
        self.addMacro('code', ['code'])
        self.addMacros('f', ['link', 'name', 'text'], ['protos'])
        self.addMacros('s', ['link', 'name', 'text'], ['structs', 'handles'])
        self.addMacros('e', ['link', 'name', 'text'], ['enums'])
        self.addMacros('p', ['name', 'text'], ['parameter', 'member'])
        self.addMacros('t', ['link', 'name'], ['funcpointers'])
        self.addMacros('d', ['link', 'name'], ['defines', 'configdefines'])

        for macro in NON_EXISTENT_MACROS:
            # Still search for them
            self.addMacro(macro, None)

    def _basicPopulateEntities(self, registry):
        """Contains typical calls to self.addEntity().

        If you need to change any of these, do so in your override of populateEntities(),
        which will be called first.
        """
        system_types = set(self.getSystemTypes())
        for t in system_types:
            self.addEntity(t, 'code', generates=False)

        for name, info in registry.typedict.items():
            if name in system_types:
                # We already added these.
                continue

            requires = info.elem.get('requires')

            if requires and not requires.lower().startswith(self.name_prefix):
                # This is an externally-defined type, will skip it.
                continue

            # OK, we might actually add an entity here
            self.handleType(name=name, info=info, requires=requires)

        for name, info in registry.enumdict.items():
            self.handleEnumValue(name, info)

        for name, info in registry.cmddict.items():
            self.handleCommand(name, info)

        for name, info in registry.extdict.items():
            self.handleExtension(name, info)
660