• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2#
3# Copyright 2017-2022 The Khronos Group Inc.
4# SPDX-License-Identifier: Apache-2.0
5
6"""Generate a mapping of extension name -> all required extension names for
7   that extension, from dependencies in the API XML."""
8
9import argparse
10import errno
11import xml.etree.ElementTree as etree
12from pathlib import Path
13
14from apiconventions import APIConventions
15
class DiGraph:
    """A directed graph.

    The implementation and API mimic that of networkx.DiGraph in networkx-1.11.
    networkx implements graphs as nested dicts; it uses dicts all the way
    down, no lists.

    Some major differences between this implementation and that of
    networkx-1.11 are:

        * This omits edge and node attribute data, because we never use them
          yet they add additional code complexity.

        * This returns iterator objects when possible instead of collection
          objects, because it simplifies the implementation and should provide
          better performance.
    """

    def __init__(self):
        # Maps each node to its DiGraphNode, whose `adj` set holds the
        # destinations of the node's outgoing edges.
        self.__nodes = {}

    def add_node(self, node):
        """Add node to the graph. Adding a node that already exists is a no-op."""
        if node not in self.__nodes:
            self.__nodes[node] = DiGraphNode()

    def add_edge(self, src, dest):
        """Add the directed edge src -> dest, adding either endpoint if needed.

        Adding an edge that already exists is a no-op, since successors are
        kept in a set.
        """
        self.add_node(src)
        self.add_node(dest)
        self.__nodes[src].adj.add(dest)

    def nodes(self):
        """Iterate over the nodes in the graph.

        Returns a live view of the nodes, in the order they were first added.
        """
        return self.__nodes.keys()

    def descendants(self, node):
        """
        Iterate over the nodes reachable from the given start node, excluding
        the start node itself. Each node in the graph is yielded at most once.

        Raises KeyError (when iteration begins) if the start node is not in
        the graph. The order in which nodes are yielded is unspecified.
        """

        # Implementation detail: this is an iterative traversal that uses
        # `visit_me` as a LIFO stack (append / pop from the end), so nodes
        # are visited in a depth-first-like order, NOT breadth-first.
        # Callers should not depend on the yield order.

        # All nodes seen during traversal. The start node is marked seen up
        # front so that a cycle leading back to it never yields it.
        seen = {node}

        # The stack of nodes that still need visiting. A node is marked seen
        # when it is pushed, so it is pushed (and yielded) at most once.
        visit_me = []

        def push_unseen(successors):
            # Push every not-yet-seen successor onto the stack.
            for y in successors:
                if y not in seen:
                    seen.add(y)
                    visit_me.append(y)

        # Bootstrap the traversal from the start node's direct successors.
        push_unseen(self.__nodes[node].adj)

        while visit_me:
            x = visit_me.pop()
            yield x
            push_unseen(self.__nodes[x].adj)
81
class DiGraphNode:
    """Per-node record used by DiGraph.

    Holds only the node's outgoing adjacency; no node or edge attribute
    data is stored.
    """

    def __init__(self):
        # Successors of this node: each element is the destination of an
        # edge starting at this node (filled in by DiGraph.add_edge).
        successors = set()
        self.adj = successors
86
class ApiDependencies:
    """Extension dependency graph loaded from an API XML registry.

    Only <extension> elements whose 'supported' attribute names the requested
    API are recorded. Each such extension's 'requires' attribute becomes a set
    of edges in a DiGraph, from the extension to each extension it requires.
    """

    def __init__(self,
                 registry_path = None,
                 api_name = None):
        """Load an API registry and generate extension dependencies

        registry_path - relative filename of XML registry. If not specified,
        uses the API default.

        api_name - API name for which to generate dependencies. Only
        extensions supported for that API are considered.
        """

        if registry_path is None:
            registry_path = APIConventions().registry_path
        if api_name is None:
            api_name = APIConventions().xml_api_name

        self.allExts = set()
        self.khrExts = set()
        self.graph = DiGraph()
        # Not populated anywhere in this class.
        # NOTE(review): presumably kept for external readers; confirm.
        self.extensions = {}
        self.tree = etree.parse(registry_path)

        # Loop over all supported extensions, creating a digraph of the
        # extension dependencies in the 'requires' attribute, which is a
        # comma-separated list of extension names. Also track sets of
        # all extensions and all KHR extensions.
        for elem in self.tree.findall('extensions/extension'):
            if not self._supports_api(elem, api_name):
                # Skip extensions not supported for this API.
                continue

            name = elem.get('name')
            self.allExts.add(name)

            # NOTE(review): substring test, so 'KHR' anywhere in the name
            # counts; confirm that only the vendor tag is intended.
            if 'KHR' in name:
                self.khrExts.add(name)

            # Add the node unconditionally so that extensions without
            # dependencies are still present in the graph.
            self.graph.add_node(name)
            for dep in self._split_names(elem.get('requires')):
                self.graph.add_edge(name, dep)

    @staticmethod
    def _split_names(value):
        """Split a comma-separated XML attribute value into a list of names.

        Returns an empty list for None or an empty string. Whitespace around
        each name is stripped and empty entries are dropped.
        """
        if not value:
            return []
        return [part.strip() for part in value.split(',') if part.strip()]

    @classmethod
    def _supports_api(cls, elem, api_name):
        """Return True if elem's 'supported' attribute lists api_name.

        'supported' is a comma-separated list of XML API names. A missing or
        empty attribute means "not supported" rather than raising
        AttributeError on None.
        """
        return api_name in cls._split_names(elem.get('supported'))

    def allExtensions(self):
        """Returns a set of all extensions in the graph"""
        return self.allExts

    def khrExtensions(self):
        """Returns a set of all KHR extensions in the graph"""
        return self.khrExts

    def children(self, extension):
        """Returns a set of all extensions the given extension depends on,
           directly or transitively, excluding the extension itself.

           The result may include required extensions that are not in
           allExtensions(): 'requires' targets are added to the graph
           without checking their own 'supported' attribute.

           Throws an exception if the extension is not one of
           allExtensions() (i.e. not supported for the selected API)."""

        if extension not in self.allExts:
            raise Exception(f'Extension {extension} not found in XML!')

        return set(self.graph.descendants(extension))
153
154
# Test script: time repeated registry loads and, with -test, print the
# transitive dependencies of one extension.
if __name__ == '__main__':
    import time

    default_registry = APIConventions().registry_path

    parser = argparse.ArgumentParser()

    parser.add_argument('-registry', action='store',
                        default=default_registry,
                        help='Use specified registry file instead of ' + default_registry)
    parser.add_argument('-loops', action='store',
                        default=20, type=int,
                        help='Number of timing loops to run')
    parser.add_argument('-test', action='store',
                        default=None,
                        help='Specify extension to find dependencies of')

    args = parser.parse_args()

    # At least one load is required: the per-loop time divides by the loop
    # count, and -test needs a loaded ApiDependencies object.
    if args.loops < 1:
        parser.error('-loops must be at least 1')

    startTime = time.process_time()

    for _ in range(args.loops):
        deps = ApiDependencies(args.registry)

    endTime = time.process_time()

    deltaT = endTime - startTime
    print('Total time = {} time/loop = {}'.format(deltaT, deltaT / args.loops))

    # -test was previously accepted but ignored; report its dependencies.
    # children() raises for an extension not supported by the registry's API.
    if args.test is not None:
        required = sorted(deps.children(args.test))
        print('Dependencies of {}: {}'.format(
            args.test, ', '.join(required) if required else '(none)'))
181