#!/usr/bin/env python3
#
# Copyright 2017-2022 The Khronos Group Inc.
# SPDX-License-Identifier: Apache-2.0

"""Generate a mapping of extension name -> all required extension names for
   that extension, from dependencies in the API XML."""

import argparse
import errno
import xml.etree.ElementTree as etree
from pathlib import Path

from apiconventions import APIConventions

class DiGraph:
    """A directed graph.

    The implementation and API mimic that of networkx.DiGraph in networkx-1.11.
    networkx implements graphs as nested dicts; it uses dicts all the way
    down, no lists.

    Some major differences between this implementation and that of
    networkx-1.11 are:

        * This omits edge and node attribute data, because we never use them
          yet they add additional code complexity.

        * This returns iterator objects when possible instead of collection
          objects, because it simplifies the implementation and should provide
          better performance.
    """

    def __init__(self):
        # Maps each node to its DiGraphNode, which holds the node's
        # outgoing edges.
        self.__nodes = {}

    def add_node(self, node):
        """Add a node to the graph. Adding an existing node is a no-op."""
        if node not in self.__nodes:
            self.__nodes[node] = DiGraphNode()

    def add_edge(self, src, dest):
        """Add the directed edge src -> dest, creating either node if needed."""
        self.add_node(src)
        self.add_node(dest)
        self.__nodes[src].adj.add(dest)

    def nodes(self):
        """Iterate over the nodes in the graph."""
        return self.__nodes.keys()

    def descendants(self, node):
        """
        Iterate over the nodes reachable from the given start node, excluding
        the start node itself. Each node in the graph is yielded at most once.

        This is a generator: if the start node is not in the graph, a KeyError
        is raised when iteration begins, not when this method is called.
        """

        # Implementation detail: this is a depth-first traversal using an
        # explicit stack (list.pop() takes from the end). The yield order
        # is not part of the contract; only the set of yielded nodes is.

        # All nodes seen during traversal. The start node is marked as seen
        # up front so that cycles leading back to it never yield it.
        seen = set()

        # The stack of nodes that need visiting.
        visit_me = []

        # Bootstrap the traversal.
        seen.add(node)
        for x in self.__nodes[node].adj:
            if x not in seen:
                seen.add(x)
                visit_me.append(x)

        while visit_me:
            x = visit_me.pop()
            assert x in seen
            yield x

            for y in self.__nodes[x].adj:
                if y not in seen:
                    seen.add(y)
                    visit_me.append(y)

class DiGraphNode:
    """Per-node storage for DiGraph."""

    def __init__(self):
        # Set of adjacent nodes, i.e. the targets of this node's outgoing
        # edges.
        self.adj = set()

class ApiDependencies:
    def __init__(self,
                 registry_path = None,
                 api_name = None):
        """Load an API registry and generate extension dependencies

        registry_path - relative filename of XML registry. If not specified,
        uses the API default.

        api_name - API name for which to generate dependencies. Only
        extensions supported for that API are considered.
        """

        if registry_path is None:
            registry_path = APIConventions().registry_path
        if api_name is None:
            api_name = APIConventions().xml_api_name

        # Names of all extensions supported for api_name.
        self.allExts = set()
        # Subset of allExts whose name contains 'KHR'.
        self.khrExts = set()
        # Edges run from an extension to each extension named in its
        # 'requires' attribute.
        self.graph = DiGraph()
        # Not populated anywhere in this file; kept for compatibility.
        self.extensions = {}
        self.tree = etree.parse(registry_path)

        # Loop over all supported extensions, creating a digraph of the
        # extension dependencies in the 'requires' attribute, which is a
        # comma-separated list of extension names. Also track lists of
        # all extensions and all KHR extensions.
        for elem in self.tree.findall('extensions/extension'):
            name = elem.get('name')

            # This works for the present form of the 'supported' attribute,
            # which is a comma-separated list of XML API names. An extension
            # with no 'supported' attribute is treated as unsupported
            # instead of failing on None.split().
            supported = elem.get('supported', '')
            if api_name not in supported.split(','):
                # Skip unsupported extensions
                continue

            self.allExts.add(name)

            if 'KHR' in name:
                self.khrExts.add(name)

            deps = elem.get('requires')
            if deps:
                for dep in deps.split(','):
                    self.graph.add_edge(name, dep)
            else:
                # Register extensions with no dependencies as well, so that
                # descendants() can be called on every supported extension.
                self.graph.add_node(name)

    def allExtensions(self):
        """Returns a set of all extensions in the graph"""
        return self.allExts

    def khrExtensions(self):
        """Returns a set of all KHR extensions in the graph"""
        return self.khrExts

    def children(self, extension):
        """Returns a set of the dependencies of an extension, including
        indirect (transitive) dependencies.
        Throws an exception if the extension is not in the graph."""

        if extension not in self.allExts:
            raise Exception(f'Extension {extension} not found in XML!')

        return set(self.graph.descendants(extension))


# Test script
if __name__ == '__main__':
    import time

    # Evaluate the default registry path once rather than constructing
    # APIConventions twice.
    default_registry = APIConventions().registry_path

    parser = argparse.ArgumentParser()

    parser.add_argument('-registry', action='store',
                        default=default_registry,
                        help='Use specified registry file instead of ' + default_registry)
    parser.add_argument('-loops', action='store',
                        default=20, type=int,
                        help='Number of timing loops to run')
    parser.add_argument('-test', action='store',
                        default=None,
                        help='Specify extension to find dependencies of')

    args = parser.parse_args()

    deps = None
    startTime = time.process_time()

    for _ in range(args.loops):
        deps = ApiDependencies(args.registry)

    endTime = time.process_time()

    # Only report timing when at least one loop ran; otherwise the
    # per-loop division would raise ZeroDivisionError.
    if args.loops > 0:
        deltaT = endTime - startTime
        print('Total time = {} time/loop = {}'.format(deltaT, deltaT / args.loops))

    # Honor -test, which was previously parsed but ignored.
    if args.test is not None:
        if deps is None:
            deps = ApiDependencies(args.registry)
        print('Dependencies of {}:'.format(args.test))
        for dep in sorted(deps.children(args.test)):
            print('    {}'.format(dep))