#!/usr/bin/env python3
#
# Copyright 2017-2023 The Khronos Group Inc.
# SPDX-License-Identifier: Apache-2.0

"""Generate a mapping of extension name -> all required extension names for
   that extension, from dependencies in the API XML."""

import argparse
import errno
import xml.etree.ElementTree as etree
from pathlib import Path

from apiconventions import APIConventions
from parse_dependency import dependencyNames


class DiGraph:
    """A directed graph.

    The implementation and API mimic that of networkx.DiGraph in networkx-1.11.
    networkx implements graphs as nested dicts; it uses dicts all the way
    down, no lists.

    Some major differences between this implementation and that of
    networkx-1.11 are:

        * This omits edge and node attribute data, because we never use them
          yet they add additional code complexity.

        * This returns iterator objects when possible instead of collection
          objects, because it simplifies the implementation and should provide
          better performance.
    """

    def __init__(self):
        # Maps each node to the DiGraphNode holding its outgoing edges.
        self.__nodes = {}

    def add_node(self, node):
        """Add a node to the graph. Adding an existing node is a no-op, so
        edges already recorded for that node are preserved."""
        if node not in self.__nodes:
            self.__nodes[node] = DiGraphNode()

    def add_edge(self, src, dest):
        """Add a directed edge from src to dest, creating either node if it
        is not yet in the graph."""
        self.add_node(src)
        self.add_node(dest)
        self.__nodes[src].adj.add(dest)

    def nodes(self):
        """Iterate over the nodes in the graph."""
        return self.__nodes.keys()

    def descendants(self, node):
        """
        Iterate over the nodes reachable from the given start node, excluding
        the start node itself. Each node in the graph is yielded at most once.

        This is a generator, so a KeyError for a start node that is not in
        the graph is raised when iteration begins, not when this method is
        called.
        """

        # Implementation detail: pending nodes are kept on a stack (LIFO),
        # so nodes are not visited in breadth-first order. Callers must not
        # rely on any particular ordering; adjacency is stored in sets, whose
        # iteration order is arbitrary anyway.

        # All nodes seen during traversal. Seeding this with the start node
        # guarantees it is never yielded, even if a cycle leads back to it.
        seen = {node}

        # The stack of nodes that need visiting.
        visit_me = []

        # Bootstrap the traversal.
        for x in self.__nodes[node].adj:
            if x not in seen:
                seen.add(x)
                visit_me.append(x)

        while visit_me:
            x = visit_me.pop()
            assert x in seen
            yield x

            for y in self.__nodes[x].adj:
                if y not in seen:
                    seen.add(y)
                    visit_me.append(y)


class DiGraphNode:
    """Per-node storage for DiGraph: the set of directly adjacent nodes."""

    def __init__(self):
        # Set of adjacent nodes (targets of this node's outgoing edges).
        self.adj = set()


class ApiDependencies:
    def __init__(self,
                 registry_path=None,
                 api_name=None):
        """Load an API registry and generate extension dependencies

        registry_path - relative filename of XML registry. If not specified,
            uses the API default.

        api_name - API name for which to generate dependencies. Only
            extensions supported for that API are considered.
        """

        conventions = APIConventions()
        if registry_path is None:
            registry_path = conventions.registry_path
        if api_name is None:
            api_name = conventions.xml_api_name

        self.allExts = set()
        self.khrExts = set()
        self.ratifiedExts = set()
        self.graph = DiGraph()
        self.extensions = {}
        self.tree = etree.parse(registry_path)

        # Loop over all supported extensions, creating a digraph of the
        # extension dependencies in the 'depends' attribute, which is a
        # boolean expression of core version and extension names.
        # A static dependency tree can be constructed only by treating all
        # extension names in the expression as dependencies, even though
        # that may not be true if it is of form (ext OR ext).
        # For the purpose these dependencies are used for - generating
        # specifications with required dependencies included automatically -
        # this will suffice.
        # Separately tracks lists of all extensions and all KHR extensions,
        # which are common specification targets.
        for elem in self.tree.findall('extensions/extension'):
            name = elem.get('name')
            # Default to '' so an element lacking the attribute is treated
            # as unsupported instead of raising AttributeError on None.
            supported = elem.get('supported', '')
            ratified = elem.get('ratified', '')

            if api_name not in supported.split(','):
                # Skip unsupported extensions
                continue

            self.allExts.add(name)

            if 'KHR' in name:
                self.khrExts.add(name)

            if api_name in ratified.split(','):
                self.ratifiedExts.add(name)

            self.graph.add_node(name)

            depends = elem.get('depends')
            if depends:
                # Walk a list of the leaf nodes (version and extension
                # names) in the boolean expression.
                for dep in dependencyNames(depends):
                    # Filter out version names, which are explicitly
                    # specified when building a specification.
                    if not conventions.is_api_version_name(dep):
                        self.graph.add_edge(name, dep)

    def allExtensions(self):
        """Returns a set of all extensions in the graph"""
        return self.allExts

    def khrExtensions(self):
        """Returns a set of all KHR extensions in the graph"""
        return self.khrExts

    def ratifiedExtensions(self):
        """Returns a set of all ratified extensions in the graph"""
        return self.ratifiedExts

    def children(self, extension):
        """Returns a set of the dependencies of an extension.
           Throws an exception if the extension is not in the graph."""

        if extension not in self.allExts:
            raise Exception(f'Extension {extension} not found in XML!')

        return set(self.graph.descendants(extension))


# Test script
if __name__ == '__main__':
    # Look up the default registry path once; it is used both as the
    # option default and in the help text.
    default_registry = APIConventions().registry_path

    parser = argparse.ArgumentParser()

    parser.add_argument('-registry', action='store',
                        default=default_registry,
                        help='Use specified registry file instead of ' + default_registry)
    parser.add_argument('-loops', action='store',
                        default=10, type=int,
                        help='Number of timing loops to run')
    parser.add_argument('-test', action='store',
                        default=None,
                        help='Specify extension to find dependencies of')

    args = parser.parse_args()

    deps = ApiDependencies(args.registry)
    print('KHR exts =', sorted(deps.khrExtensions()))
    print('Ratified exts =', sorted(deps.ratifiedExtensions()))

    if args.test is not None:
        # children() raises if the extension is unknown; for a test script
        # the resulting traceback is an acceptable diagnostic.
        print(f'Dependencies of {args.test} =', sorted(deps.children(args.test)))

    # Only time the registry load when at least one loop was requested;
    # this also avoids dividing by zero for '-loops 0'.
    if args.loops > 0:
        import time
        startTime = time.process_time()

        for loop in range(args.loops):
            deps = ApiDependencies(args.registry)

        endTime = time.process_time()

        deltaT = endTime - startTime
        print('Total time = {} time/loop = {}'.format(deltaT, deltaT / args.loops))